Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for mistral #1057

Merged
merged 5 commits into from
Aug 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions test/collection/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,18 @@ def test_basic_config():
}
},
),
(
Configure.Vectorizer.text2vec_mistral(
vectorize_collection_name=False,
model="cool-model",
),
{
"text2vec-mistral": {
"vectorizeClassName": False,
"model": "cool-model",
}
},
),
(
Configure.Vectorizer.text2vec_palm(
project_id="project",
Expand Down Expand Up @@ -1185,6 +1197,20 @@ def test_vector_config_flat_pq() -> None:
}
},
),
(
[Configure.NamedVectors.text2vec_mistral(name="test", source_properties=["prop"])],
{
"test": {
"vectorizer": {
"text2vec-mistral": {
"vectorizeClassName": True,
"properties": ["prop"],
}
},
"vectorIndexType": "hnsw",
}
},
),
(
[
Configure.NamedVectors.text2vec_palm(
Expand Down
5 changes: 5 additions & 0 deletions test/collection/test_vectorizer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from weaviate.collections.classes.config import Configure


def test_multi2vec_clip() -> None:
    """Smoke test: a `multi2vec_clip` config built from image fields alone must not raise."""
    image_fields = ["test"]
    # No assertion on the result; the test passes as long as construction succeeds.
    Configure.Vectorizer.multi2vec_clip(image_fields=image_fields)
37 changes: 37 additions & 0 deletions weaviate/collections/classes/config_named_vectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
_Text2VecGPT4AllConfigCreate,
_Text2VecHuggingFaceConfigCreate,
_Text2VecJinaConfigCreate,
_Text2VecMistralConfig,
_Text2VecOctoConfig,
_Text2VecOllamaConfig,
_Text2VecOpenAIConfigCreate,
Expand Down Expand Up @@ -225,6 +226,42 @@ def text2vec_contextionary(
vector_index_config=vector_index_config,
)

@staticmethod
def text2vec_mistral(
    name: str,
    *,
    source_properties: Optional[List[str]] = None,
    vector_index_config: Optional[_VectorIndexConfigCreate] = None,
    vectorize_collection_name: bool = True,
    model: Optional[str] = None,
) -> _NamedVectorConfigCreate:
    """Create a named vector using the `text2vec-mistral` model.

    See the [documentation](https://weaviate.io/developers/weaviate/modules/retriever-vectorizer-modules/text2vec-mistral)
    for detailed usage.

    Arguments:
        `name`
            The name of the named vector.
        `source_properties`
            Which properties should be included when vectorizing. By default all text properties are included.
        `vector_index_config`
            The configuration for Weaviate's vector index. Use wvc.config.Configure.VectorIndex to create a vector index configuration. None by default
        `vectorize_collection_name`
            Whether to vectorize the collection name. Defaults to `True`.
        `model`
            The model to use. Defaults to `None`, which uses the server-defined default.

    Returns:
        A `_NamedVectorConfigCreate` wrapping a `_Text2VecMistralConfig` vectorizer.
    """
    return _NamedVectorConfigCreate(
        name=name,
        source_properties=source_properties,
        vectorizer=_Text2VecMistralConfig(
            model=model,
            # The module setting keeps the server-side key name `vectorizeClassName`.
            vectorizeClassName=vectorize_collection_name,
        ),
        vector_index_config=vector_index_config,
    )

@staticmethod
def text2vec_octoai(
name: str,
Expand Down
28 changes: 28 additions & 0 deletions weaviate/collections/classes/config_vectorizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ class Vectorizers(str, Enum):
TEXT2VEC_CONTEXTIONARY = "text2vec-contextionary"
TEXT2VEC_GPT4ALL = "text2vec-gpt4all"
TEXT2VEC_HUGGINGFACE = "text2vec-huggingface"
TEXT2VEC_MISTRAL = "text2vec-mistral"
TEXT2VEC_OCTOAI = "text2vec-octoai"
TEXT2VEC_OLLAMA = "text2vec-ollama"
TEXT2VEC_OPENAI = "text2vec-openai"
Expand Down Expand Up @@ -230,6 +231,14 @@ class _Text2VecHuggingFaceConfigCreate(_Text2VecHuggingFaceConfig, _VectorizerCo
pass


class _Text2VecMistralConfig(_VectorizerConfigCreate):
    """Creation-time settings for the `text2vec-mistral` vectorizer module."""

    # Module identifier: defaults to `text2vec-mistral`, is frozen, and is
    # excluded when the settings themselves are serialized.
    vectorizer: Union[Vectorizers, _EnumLikeStr] = Field(
        default=Vectorizers.TEXT2VEC_MISTRAL,
        frozen=True,
        exclude=True,
    )
    # Model name passed through to the module; may be None.
    model: Optional[str]
    # Whether the collection (class) name is included when vectorizing.
    vectorizeClassName: bool


OpenAIType = Literal["text", "code"]


Expand Down Expand Up @@ -805,6 +814,25 @@ def text2vec_huggingface(
vectorizeClassName=vectorize_collection_name,
)

@staticmethod
def text2vec_mistral(
    *,
    model: Optional[str] = None,
    vectorize_collection_name: bool = True,
) -> _VectorizerConfigCreate:
    """Build the vectorizer configuration for the `text2vec-mistral` module.

    See the [documentation](https://weaviate.io/developers/weaviate/modules/retriever-vectorizer-modules/text2vec-mistral)
    for detailed usage.

    Arguments:
        `model`
            The model to use. Defaults to `None`, which uses the server-defined default.
        `vectorize_collection_name`
            Whether to vectorize the collection name. Defaults to `True`.

    Returns:
        A `_Text2VecMistralConfig` holding the given settings.
    """
    mistral_config = _Text2VecMistralConfig(
        model=model,
        vectorizeClassName=vectorize_collection_name,
    )
    return mistral_config

@staticmethod
def text2vec_octoai(
*,
Expand Down
35 changes: 29 additions & 6 deletions weaviate/connect/integrations.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,16 @@ class _IntegrationConfigJina(_IntegrationConfig):
base_url: Optional[str] = Field(serialization_alias="X-Jinaai-Baseurl")


class _IntegrationConfigMistral(_IntegrationConfig):
    """Mistral integration settings; every field is serialized under its `X-Mistral-*` alias."""

    api_key: str = Field(serialization_alias="X-Mistral-Api-Key")
    # NOTE(review): this field is spelled `request_...` (singular) while the
    # OctoAI config spells it `requests_per_minute_embeddings`. It is part of
    # the public keyword interface, so it is intentionally left unchanged here.
    request_per_minute_embeddings: Optional[int] = Field(
        serialization_alias="X-Mistral-Ratelimit-RequestPM-Embedding",
    )
    tokens_per_minute_embeddings: Optional[int] = Field(
        serialization_alias="X-Mistral-Ratelimit-TokenPM-Embedding",
    )


class _IntegrationConfigOcto(_IntegrationConfig):
api_key: str = Field(serialization_alias="X-OctoAI-Api-Key")
requests_per_minute_embeddings: Optional[int] = Field(
Expand All @@ -87,7 +97,7 @@ def cohere(
*,
api_key: str,
base_url: Optional[str] = None,
requests_per_minute_embeddings: Optional[int] = None
requests_per_minute_embeddings: Optional[int] = None,
) -> _IntegrationConfig:
return _IntegrationConfigCohere(
api_key=api_key,
Expand All @@ -100,7 +110,7 @@ def huggingface(
*,
api_key: str,
requests_per_minute_embeddings: Optional[int] = None,
base_url: Optional[str] = None
base_url: Optional[str] = None,
) -> _IntegrationConfig:
return _IntegrationConfigHuggingface(
api_key=api_key,
Expand All @@ -115,7 +125,7 @@ def openai(
requests_per_minute_embeddings: Optional[int] = None,
tokens_per_minute_embeddings: Optional[int] = None,
organization: Optional[str] = None,
base_url: Optional[str] = None
base_url: Optional[str] = None,
) -> _IntegrationConfig:
return _IntegrationConfigOpenAi(
api_key=api_key,
Expand Down Expand Up @@ -147,7 +157,7 @@ def voyageai(
api_key: str,
requests_per_minute_embeddings: Optional[int] = None,
tokens_per_minute_embeddings: Optional[int] = None,
base_url: Optional[str] = None
base_url: Optional[str] = None,
) -> _IntegrationConfig:
return _IntegrationConfigVoyage(
api_key=api_key,
Expand All @@ -161,7 +171,7 @@ def jinaai(
*,
api_key: str,
requests_per_minute_embeddings: Optional[int] = None,
base_url: Optional[str] = None
base_url: Optional[str] = None,
) -> _IntegrationConfig:
return _IntegrationConfigJina(
api_key=api_key,
Expand All @@ -174,10 +184,23 @@ def octoai(
*,
api_key: str,
requests_per_minute_embeddings: Optional[int] = None,
base_url: Optional[str] = None
base_url: Optional[str] = None,
) -> _IntegrationConfig:
return _IntegrationConfigOcto(
api_key=api_key,
requests_per_minute_embeddings=requests_per_minute_embeddings,
base_url=base_url,
)

@staticmethod
def mistral(
    *,
    api_key: str,
    request_per_minute_embeddings: Optional[int] = None,
    tokens_per_minute_embeddings: Optional[int] = None,
) -> _IntegrationConfig:
    """Create the integration configuration for Mistral.

    Arguments:
        `api_key`
            The Mistral API key.
        `request_per_minute_embeddings`
            Optional requests-per-minute rate limit for embeddings. Defaults to `None`.
        `tokens_per_minute_embeddings`
            Optional tokens-per-minute rate limit for embeddings. Defaults to `None`.

    Returns:
        An `_IntegrationConfigMistral` carrying the given values.
    """
    mistral_config = _IntegrationConfigMistral(
        api_key=api_key,
        request_per_minute_embeddings=request_per_minute_embeddings,
        tokens_per_minute_embeddings=tokens_per_minute_embeddings,
    )
    return mistral_config
Loading