Skip to content

Commit

Permalink
Merge pull request #1057 from weaviate/mistral_text2vec
Browse files Browse the repository at this point in the history
Add support for mistral
  • Loading branch information
dirkkul authored Aug 29, 2024
2 parents c837f22 + a06d6bd commit 34e9911
Show file tree
Hide file tree
Showing 5 changed files with 125 additions and 6 deletions.
26 changes: 26 additions & 0 deletions test/collection/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,18 @@ def test_basic_config():
}
},
),
(
Configure.Vectorizer.text2vec_mistral(
vectorize_collection_name=False,
model="cool-model",
),
{
"text2vec-mistral": {
"vectorizeClassName": False,
"model": "cool-model",
}
},
),
(
Configure.Vectorizer.text2vec_palm(
project_id="project",
Expand Down Expand Up @@ -1215,6 +1227,20 @@ def test_vector_config_flat_pq() -> None:
}
},
),
(
[Configure.NamedVectors.text2vec_mistral(name="test", source_properties=["prop"])],
{
"test": {
"vectorizer": {
"text2vec-mistral": {
"vectorizeClassName": True,
"properties": ["prop"],
}
},
"vectorIndexType": "hnsw",
}
},
),
(
[
Configure.NamedVectors.text2vec_palm(
Expand Down
5 changes: 5 additions & 0 deletions test/collection/test_vectorizer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from weaviate.collections.classes.config import Configure


def test_multi2vec_clip() -> None:
    """Smoke-test that a `multi2vec_clip` vectorizer config can be built from image fields only.

    There are no assertions: the test passes as long as the factory call does not raise.
    """
    image_fields = ["test"]
    Configure.Vectorizer.multi2vec_clip(image_fields=image_fields)
37 changes: 37 additions & 0 deletions weaviate/collections/classes/config_named_vectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
_Text2VecGPT4AllConfigCreate,
_Text2VecHuggingFaceConfigCreate,
_Text2VecJinaConfigCreate,
_Text2VecMistralConfig,
_Text2VecOctoConfig,
_Text2VecOllamaConfig,
_Text2VecOpenAIConfigCreate,
Expand Down Expand Up @@ -225,6 +226,42 @@ def text2vec_contextionary(
vector_index_config=vector_index_config,
)

@staticmethod
def text2vec_mistral(
    name: str,
    *,
    source_properties: Optional[List[str]] = None,
    vector_index_config: Optional[_VectorIndexConfigCreate] = None,
    vectorize_collection_name: bool = True,
    model: Optional[str] = None,
) -> _NamedVectorConfigCreate:
    """Build a named vector configuration that uses the `text2vec-mistral` model.

    See the [documentation](https://weaviate.io/developers/weaviate/modules/retriever-vectorizer-modules/text2vec-mistral)
    for detailed usage.

    Arguments:
        `name`
            The name of the named vector.
        `source_properties`
            Which properties should be included when vectorizing. By default all text properties are included.
        `vector_index_config`
            The configuration for Weaviate's vector index. Use wvc.config.Configure.VectorIndex to create a vector index configuration. None by default
        `vectorize_collection_name`
            Whether to vectorize the collection name. Defaults to `True`.
        `model`
            The model to use. Defaults to `None`, which uses the server-defined default.
    """
    # Assemble the module-specific settings first, then wrap them in the
    # generic named-vector container together with the index configuration.
    mistral_settings = _Text2VecMistralConfig(
        model=model,
        vectorizeClassName=vectorize_collection_name,
    )
    return _NamedVectorConfigCreate(
        name=name,
        source_properties=source_properties,
        vectorizer=mistral_settings,
        vector_index_config=vector_index_config,
    )

@staticmethod
def text2vec_octoai(
name: str,
Expand Down
28 changes: 28 additions & 0 deletions weaviate/collections/classes/config_vectorizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ class Vectorizers(str, Enum):
TEXT2VEC_CONTEXTIONARY = "text2vec-contextionary"
TEXT2VEC_GPT4ALL = "text2vec-gpt4all"
TEXT2VEC_HUGGINGFACE = "text2vec-huggingface"
TEXT2VEC_MISTRAL = "text2vec-mistral"
TEXT2VEC_OCTOAI = "text2vec-octoai"
TEXT2VEC_OLLAMA = "text2vec-ollama"
TEXT2VEC_OPENAI = "text2vec-openai"
Expand Down Expand Up @@ -230,6 +231,14 @@ class _Text2VecHuggingFaceConfigCreate(_Text2VecHuggingFaceConfig, _VectorizerCo
pass


class _Text2VecMistralConfig(_VectorizerConfigCreate):
    # Settings model for the `text2vec-mistral` vectorizer module.

    # Module identifier: defaults to `Vectorizers.TEXT2VEC_MISTRAL`, is declared
    # frozen, and is flagged `exclude=True` so it is not part of the dumped settings.
    vectorizer: Union[Vectorizers, _EnumLikeStr] = Field(
        default=Vectorizers.TEXT2VEC_MISTRAL,
        frozen=True,
        exclude=True,
    )
    # Name of the model to use; `None` is an accepted value.
    model: Optional[str]
    # Whether the collection name is vectorized.
    vectorizeClassName: bool


OpenAIType = Literal["text", "code"]


Expand Down Expand Up @@ -805,6 +814,25 @@ def text2vec_huggingface(
vectorizeClassName=vectorize_collection_name,
)

@staticmethod
def text2vec_mistral(
    *,
    model: Optional[str] = None,
    vectorize_collection_name: bool = True,
) -> _VectorizerConfigCreate:
    """Build a `_Text2VecMistralConfig` object for vectorizing with the `text2vec-mistral` model.

    See the [documentation](https://weaviate.io/developers/weaviate/modules/retriever-vectorizer-modules/text2vec-mistral)
    for detailed usage.

    Arguments:
        `model`
            The model to use. Defaults to `None`, which uses the server-defined default.
        `vectorize_collection_name`
            Whether to vectorize the collection name. Defaults to `True`.
    """
    # The public keyword `vectorize_collection_name` maps onto the
    # `vectorizeClassName` field of the config model.
    mistral_settings = _Text2VecMistralConfig(
        model=model,
        vectorizeClassName=vectorize_collection_name,
    )
    return mistral_settings

@staticmethod
def text2vec_octoai(
*,
Expand Down
35 changes: 29 additions & 6 deletions weaviate/connect/integrations.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,16 @@ class _IntegrationConfigJina(_IntegrationConfig):
base_url: Optional[str] = Field(serialization_alias="X-Jinaai-Baseurl")


class _IntegrationConfigMistral(_IntegrationConfig):
    # Integration settings for Mistral. Each field declares the alias it is
    # serialized under.
    #
    # NOTE(review): `request_per_minute_embeddings` is spelled without the
    # plural "s" that `_IntegrationConfigOcto` uses
    # (`requests_per_minute_embeddings`). Renaming it would change the public
    # keyword of the `mistral` factory, so it is left as is — confirm whether
    # the difference is intentional.

    # API key for Mistral.
    api_key: str = Field(serialization_alias="X-Mistral-Api-Key")
    # Embedding request rate limit; no default is declared on the field.
    request_per_minute_embeddings: Optional[int] = Field(
        serialization_alias="X-Mistral-Ratelimit-RequestPM-Embedding",
    )
    # Embedding token rate limit; no default is declared on the field.
    tokens_per_minute_embeddings: Optional[int] = Field(
        serialization_alias="X-Mistral-Ratelimit-TokenPM-Embedding",
    )


class _IntegrationConfigOcto(_IntegrationConfig):
api_key: str = Field(serialization_alias="X-OctoAI-Api-Key")
requests_per_minute_embeddings: Optional[int] = Field(
Expand All @@ -87,7 +97,7 @@ def cohere(
*,
api_key: str,
base_url: Optional[str] = None,
requests_per_minute_embeddings: Optional[int] = None
requests_per_minute_embeddings: Optional[int] = None,
) -> _IntegrationConfig:
return _IntegrationConfigCohere(
api_key=api_key,
Expand All @@ -100,7 +110,7 @@ def huggingface(
*,
api_key: str,
requests_per_minute_embeddings: Optional[int] = None,
base_url: Optional[str] = None
base_url: Optional[str] = None,
) -> _IntegrationConfig:
return _IntegrationConfigHuggingface(
api_key=api_key,
Expand All @@ -115,7 +125,7 @@ def openai(
requests_per_minute_embeddings: Optional[int] = None,
tokens_per_minute_embeddings: Optional[int] = None,
organization: Optional[str] = None,
base_url: Optional[str] = None
base_url: Optional[str] = None,
) -> _IntegrationConfig:
return _IntegrationConfigOpenAi(
api_key=api_key,
Expand Down Expand Up @@ -147,7 +157,7 @@ def voyageai(
api_key: str,
requests_per_minute_embeddings: Optional[int] = None,
tokens_per_minute_embeddings: Optional[int] = None,
base_url: Optional[str] = None
base_url: Optional[str] = None,
) -> _IntegrationConfig:
return _IntegrationConfigVoyage(
api_key=api_key,
Expand All @@ -161,7 +171,7 @@ def jinaai(
*,
api_key: str,
requests_per_minute_embeddings: Optional[int] = None,
base_url: Optional[str] = None
base_url: Optional[str] = None,
) -> _IntegrationConfig:
return _IntegrationConfigJina(
api_key=api_key,
Expand All @@ -174,10 +184,23 @@ def octoai(
*,
api_key: str,
requests_per_minute_embeddings: Optional[int] = None,
base_url: Optional[str] = None
base_url: Optional[str] = None,
) -> _IntegrationConfig:
return _IntegrationConfigOcto(
api_key=api_key,
requests_per_minute_embeddings=requests_per_minute_embeddings,
base_url=base_url,
)

@staticmethod
def mistral(
    *,
    api_key: str,
    request_per_minute_embeddings: Optional[int] = None,
    tokens_per_minute_embeddings: Optional[int] = None,
) -> _IntegrationConfig:
    """Build the integration settings for Mistral.

    Arguments:
        `api_key`
            The Mistral API key.
        `request_per_minute_embeddings`
            Optional embedding request rate limit. Defaults to `None`.
        `tokens_per_minute_embeddings`
            Optional embedding token rate limit. Defaults to `None`.
    """
    # All three values are passed explicitly, including `None`, because the
    # model fields declare no defaults of their own.
    mistral_integration = _IntegrationConfigMistral(
        api_key=api_key,
        request_per_minute_embeddings=request_per_minute_embeddings,
        tokens_per_minute_embeddings=tokens_per_minute_embeddings,
    )
    return mistral_integration

0 comments on commit 34e9911

Please sign in to comment.