From a9142b942b0c1e4671f8d54e163b38740ad8a260 Mon Sep 17 00:00:00 2001 From: pattonjp Date: Fri, 8 Dec 2023 11:30:50 -0600 Subject: [PATCH] remove tei embeddings --- embeddings/tei/doc.go | 8 -- embeddings/tei/options.go | 104 ------------------ embeddings/tei/text_embeddings_inference..go | 85 -------------- embeddings/tei_test.go | 67 ----------- .../milvus_vectorstore_example.go | 8 +- llms/tei/doc.go | 8 -- llms/tei/options.go | 89 --------------- llms/tei/text_embeddings_inference..go | 43 -------- vectorstores/milvus/milvus_test.go | 18 +-- 9 files changed, 14 insertions(+), 416 deletions(-) delete mode 100644 embeddings/tei/doc.go delete mode 100644 embeddings/tei/options.go delete mode 100644 embeddings/tei/text_embeddings_inference..go delete mode 100644 embeddings/tei_test.go delete mode 100644 llms/tei/doc.go delete mode 100644 llms/tei/options.go delete mode 100644 llms/tei/text_embeddings_inference..go diff --git a/embeddings/tei/doc.go b/embeddings/tei/doc.go deleted file mode 100644 index 0a05401fc..000000000 --- a/embeddings/tei/doc.go +++ /dev/null @@ -1,8 +0,0 @@ -/* -Huggingface Text Embeddings Inference -https://github.com/huggingface/text-embeddings-inference - -package is a wrapper for the Huggingface text embeddings inference project -that can be run locally for creating vector embeddings. -*/ -package tei diff --git a/embeddings/tei/options.go b/embeddings/tei/options.go deleted file mode 100644 index c0601d9de..000000000 --- a/embeddings/tei/options.go +++ /dev/null @@ -1,104 +0,0 @@ -package tei - -import ( - "errors" - "runtime" - "time" - - client "github.com/gage-technologies/tei-go" -) - -const ( - _defaultBatchSize = 512 - _defaultStripNewLines = true - _defaultTimeNanoSeconds = 60 * 1000000000 -) - -var ErrMissingAPIBaseURL = errors.New("missing the API Base URL") //nolint:lll - -type Option func(emb *TextEmbeddingsInference) - -// WithStripNewLines is an option for specifying the should it strip new lines. -func WithStripNewLines(stripNewLines bool) Option { - return func(p *TextEmbeddingsInference) { - p.StripNewLines = stripNewLines - } -} - -// WithPoolSize is an option for specifying the number of goroutines. -func WithPoolSize(poolSize int) Option { - return func(p *TextEmbeddingsInference) { - p.poolSize = poolSize - } -} - -// WithBatchSize is an option for specifying the batch size. -func WithBatchSize(batchSize int) Option { - return func(p *TextEmbeddingsInference) { - p.BatchSize = batchSize - } -} - -// WithAPIBaseURL adds base url for api. -func WithAPIBaseURL(url string) Option { - return func(emb *TextEmbeddingsInference) { - emb.baseURL = url - } -} - -// WithHeaders add request headers. -func WithHeaders(headers map[string]string) Option { - return func(emb *TextEmbeddingsInference) { - if emb.headers == nil { - emb.headers = make(map[string]string, len(headers)) - } - for k, v := range headers { - emb.headers[k] = v - } - } -} - -// WithCookies add request cookies. -func WithCookies(cookies map[string]string) Option { - return func(emb *TextEmbeddingsInference) { - if emb.cookies == nil { - emb.cookies = make(map[string]string, len(cookies)) - } - for k, v := range cookies { - emb.cookies[k] = v - } - } -} - -// WithTimeout set the request timeout. -func WithTimeout(dur time.Duration) Option { - return func(emb *TextEmbeddingsInference) { - emb.timeout = dur - } -} - -// WithTruncate set the embedder to truncate input length. -func WithTruncate() Option { - return func(emb *TextEmbeddingsInference) { - emb.truncate = true - } -} - -func applyClientOptions(opts ...Option) (TextEmbeddingsInference, error) { - emb := TextEmbeddingsInference{ - StripNewLines: _defaultStripNewLines, - BatchSize: _defaultBatchSize, - timeout: time.Duration(_defaultTimeNanoSeconds), - poolSize: runtime.GOMAXPROCS(0), - } - for _, opt := range opts { - opt(&emb) - } - if emb.baseURL == "" { - return emb, ErrMissingAPIBaseURL - } - if emb.client == nil { - emb.client = client.NewClient(emb.baseURL, emb.headers, emb.cookies, emb.timeout) - } - return emb, nil -} diff --git a/embeddings/tei/text_embeddings_inference..go b/embeddings/tei/text_embeddings_inference..go deleted file mode 100644 index 1f50a8358..000000000 --- a/embeddings/tei/text_embeddings_inference..go +++ /dev/null @@ -1,85 +0,0 @@ -package tei - -import ( - "context" - "strings" - "time" - - client "github.com/gage-technologies/tei-go" - "github.com/sourcegraph/conc/pool" - "github.com/tmc/langchaingo/embeddings" -) - -type TextEmbeddingsInference struct { - client *client.Client - StripNewLines bool - truncate bool - BatchSize int - baseURL string - headers map[string]string - cookies map[string]string - timeout time.Duration - poolSize int -} - -var _ embeddings.Embedder = TextEmbeddingsInference{} - -func New(opts ...Option) (TextEmbeddingsInference, error) { - emb, err := applyClientOptions(opts...) - if err != nil { - return emb, err - } - emb.client = client.NewClient(emb.baseURL, emb.headers, emb.cookies, emb.timeout) - - return emb, nil -} - -// EmbedDocuments creates one vector embedding for each of the texts. -func (e TextEmbeddingsInference) EmbedDocuments(_ context.Context, texts []string) ([][]float32, error) { - batchedTexts := embeddings.BatchTexts( - embeddings.MaybeRemoveNewLines(texts, e.StripNewLines), - e.BatchSize, - ) - - emb := make([][]float32, 0, len(texts)) - - p := pool.New().WithMaxGoroutines(e.poolSize).WithErrors() - - for _, txt := range batchedTexts { - p.Go(func() error { - curTextEmbeddings, err := e.client.Embed(strings.Join(txt, " "), e.truncate) - if err != nil { - return err - } - - textLengths := make([]int, 0, len(txt)) - for _, text := range txt { - textLengths = append(textLengths, len(text)) - } - - combined, err := embeddings.CombineVectors(curTextEmbeddings, textLengths) - if err != nil { - return err - } - - emb = append(emb, combined) - - return nil - }) - } - return emb, p.Wait() -} - -// EmbedQuery embeds a single text. -func (e TextEmbeddingsInference) EmbedQuery(_ context.Context, text string) ([]float32, error) { - if e.StripNewLines { - text = strings.ReplaceAll(text, "\n", " ") - } - - emb, err := e.client.Embed(text, false) - if err != nil { - return nil, err - } - - return emb[0], nil -} diff --git a/embeddings/tei_test.go b/embeddings/tei_test.go deleted file mode 100644 index 5760dd291..000000000 --- a/embeddings/tei_test.go +++ /dev/null @@ -1,67 +0,0 @@ -package embeddings - -import ( - "context" - "os" - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - "github.com/tmc/langchaingo/llms/tei" -) - -func newTEIEmbedder(t *testing.T, opts ...Option) *EmbedderImpl { - t.Helper() - teiURL := os.Getenv("TEI_API_URL") - if teiURL == "" { - t.Skip("TEI_API_URL not set") - return nil - } - - llm, err := tei.New( - tei.WithAPIBaseURL(teiURL), - tei.WithPoolSize(4), - ) - require.NoError(t, err) - embedder, err := NewEmbedder(llm, opts...) - require.NoError(t, err) - return embedder -} - -func TestTEIEmbeddings(t *testing.T) { - t.Parallel() - e := newTEIEmbedder(t) - texts := []string{"Hello world"} - emb, err := e.EmbedDocuments(context.Background(), texts) - require.NoError(t, err) - assert.Len(t, emb, 1) -} - -func TestTEIEmbeddingsQueryVsDocuments(t *testing.T) { - t.Parallel() - - e := newTEIEmbedder(t) - text := "hi there" - eq, err := e.EmbedQuery(context.Background(), text) - require.NoError(t, err) - - eb, err := e.EmbedDocuments(context.Background(), []string{text}) - require.NoError(t, err) - - // Using strict equality should be OK here because we expect the same values - // for the same string, deterministically. - assert.Equal(t, eq, eb[0]) -} - -func TestTEIEmbeddingsWithOptions(t *testing.T) { - t.Parallel() - - e := newTEIEmbedder(t, WithBatchSize(1), WithStripNewLines(false)) - - _, err := e.EmbedQuery(context.Background(), "Hello world!") - require.NoError(t, err) - - embeddings, err := e.EmbedDocuments(context.Background(), []string{"Hello world"}) - require.NoError(t, err) - assert.Len(t, embeddings, 1) -} diff --git a/examples/huggingface-milvus-vectorstore-example/milvus_vectorstore_example.go b/examples/huggingface-milvus-vectorstore-example/milvus_vectorstore_example.go index ceb216e2a..b5781fc7e 100644 --- a/examples/huggingface-milvus-vectorstore-example/milvus_vectorstore_example.go +++ b/examples/huggingface-milvus-vectorstore-example/milvus_vectorstore_example.go @@ -9,7 +9,7 @@ import ( "github.com/milvus-io/milvus-sdk-go/v2/client" "github.com/milvus-io/milvus-sdk-go/v2/entity" "github.com/tmc/langchaingo/embeddings" - "github.com/tmc/langchaingo/llms/tei" + "github.com/tmc/langchaingo/llms/openai" "github.com/tmc/langchaingo/schema" "github.com/tmc/langchaingo/vectorstores" @@ -18,6 +18,7 @@ import ( const ( poolSize = 5 + baseURL = "http://localhost:5500" ) func main() { @@ -29,10 +30,7 @@ func main() { } func newStore() (vectorstores.VectorStore, error) { - llm, err := tei.New( - tei.WithAPIBaseURL("http://localhost:5500"), - tei.WithPoolSize(poolSize), - ) + llm, err := openai.New(openai.WithBaseURL(baseURL)) if err != nil { log.Fatal(err) } diff --git a/llms/tei/doc.go b/llms/tei/doc.go deleted file mode 100644 index 0a05401fc..000000000 --- a/llms/tei/doc.go +++ /dev/null @@ -1,8 +0,0 @@ -/* -Huggingface Text Embeddings Inference -https://github.com/huggingface/text-embeddings-inference - -package is a wrapper for the Huggingface text embeddings inference project -that can be run locally for creating vector embeddings. -*/ -package tei diff --git a/llms/tei/options.go b/llms/tei/options.go deleted file mode 100644 index 836bb55ed..000000000 --- a/llms/tei/options.go +++ /dev/null @@ -1,89 +0,0 @@ -package tei - -import ( - "errors" - "os" - "runtime" - "time" - - client "github.com/gage-technologies/tei-go" -) - -const ( - defaultTimeNanoSeconds = 60 * 1000000000 - defaultURLEnvVarName = "TEI_API_URL" -) - -var ErrMissingAPIBaseURL = errors.New("missing the API Base URL") //nolint:lll - -type Option func(emb *TextEmbeddingsInference) - -// WithPoolSize is an option for specifying the number of goroutines. -func WithPoolSize(poolSize int) Option { - return func(p *TextEmbeddingsInference) { - p.poolSize = poolSize - } -} - -// WithAPIBaseURL adds base url for api. -func WithAPIBaseURL(url string) Option { - return func(emb *TextEmbeddingsInference) { - emb.baseURL = url - } -} - -// WithHeaders add request headers. -func WithHeaders(headers map[string]string) Option { - return func(emb *TextEmbeddingsInference) { - if emb.headers == nil { - emb.headers = make(map[string]string, len(headers)) - } - for k, v := range headers { - emb.headers[k] = v - } - } -} - -// WithCookies add request cookies. -func WithCookies(cookies map[string]string) Option { - return func(emb *TextEmbeddingsInference) { - if emb.cookies == nil { - emb.cookies = make(map[string]string, len(cookies)) - } - for k, v := range cookies { - emb.cookies[k] = v - } - } -} - -// WithTimeout set the request timeout. -func WithTimeout(dur time.Duration) Option { - return func(emb *TextEmbeddingsInference) { - emb.timeout = dur - } -} - -// WithTruncate set the embedder to truncate input length. -func WithTruncate() Option { - return func(emb *TextEmbeddingsInference) { - emb.truncate = true - } -} - -func applyClientOptions(opts ...Option) (TextEmbeddingsInference, error) { - emb := TextEmbeddingsInference{ - timeout: time.Duration(defaultTimeNanoSeconds), - poolSize: runtime.GOMAXPROCS(0), - baseURL: os.Getenv(defaultURLEnvVarName), - } - for _, opt := range opts { - opt(&emb) - } - if emb.baseURL == "" { - return emb, ErrMissingAPIBaseURL - } - if emb.client == nil { - emb.client = client.NewClient(emb.baseURL, emb.headers, emb.cookies, emb.timeout) - } - return emb, nil -} diff --git a/llms/tei/text_embeddings_inference..go b/llms/tei/text_embeddings_inference..go deleted file mode 100644 index edba867e5..000000000 --- a/llms/tei/text_embeddings_inference..go +++ /dev/null @@ -1,43 +0,0 @@ -package tei - -import ( - "context" - "time" - - client "github.com/gage-technologies/tei-go" - "github.com/sourcegraph/conc/pool" -) - -type TextEmbeddingsInference struct { - client *client.Client - truncate bool - baseURL string - headers map[string]string - cookies map[string]string - timeout time.Duration - poolSize int -} - -func New(opts ...Option) (TextEmbeddingsInference, error) { - emb, err := applyClientOptions(opts...) - if err != nil { - return emb, err - } - emb.client = client.NewClient(emb.baseURL, emb.headers, emb.cookies, emb.timeout) - - return emb, nil -} - -// CreateEmbedding creates one vector embedding for each of the texts. -func (e TextEmbeddingsInference) CreateEmbedding(_ context.Context, inputTexts []string) ([][]float32, error) { - p := pool.NewWithResults[[]float32](). - WithMaxGoroutines(e.poolSize). - WithErrors() - for _, txt := range inputTexts { - p.Go(func() ([]float32, error) { - res, err := e.client.Embed(txt, e.truncate) - return res[0], err - }) - } - return p.Wait() -} diff --git a/vectorstores/milvus/milvus_test.go b/vectorstores/milvus/milvus_test.go index b504e66e9..df6b33e34 100644 --- a/vectorstores/milvus/milvus_test.go +++ b/vectorstores/milvus/milvus_test.go @@ -9,20 +9,24 @@ import ( "github.com/milvus-io/milvus-sdk-go/v2/entity" "github.com/stretchr/testify/require" "github.com/tmc/langchaingo/embeddings" - "github.com/tmc/langchaingo/llms/tei" + "github.com/tmc/langchaingo/llms/openai" "github.com/tmc/langchaingo/schema" "github.com/tmc/langchaingo/vectorstores" ) func getEmbedder(t *testing.T) (embeddings.Embedder, error) { t.Helper() - url := os.Getenv("TEI_URL") - if url == "" { - t.Skip("must set TEI_URL to run test") + + if openaiKey := os.Getenv("OPENAI_API_KEY"); openaiKey == "" { + t.Skip("OPENAI_API_KEY not set") } - llm, err := tei.New( - tei.WithAPIBaseURL(url), - ) + url := os.Getenv("OPENAI_BASE_URL") + opts := []openai.Option{} + if url != "" { + opts = append(opts, openai.WithBaseURL(url)) + } + + llm, err := openai.New(opts...) require.NoError(t, err) return embeddings.NewEmbedder(llm) }