-
Notifications
You must be signed in to change notification settings - Fork 161
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
MariaDB has published vector support in a preview release: https://mariadb.com/kb/en/vector-overview/ MariaDB's vector features are currently available only in preview and development versions. The VectorDBBench code for MariaDB is fully functional and has been tested against the MariaDB commit https://github.com/MariaDB/server/tree/1f044176ed, a more recent development version with interface modifications and performance enhancements. However, this client in VectorDBBench will only become practically useful once the next major MariaDB release reaches general availability.
- Support MariaDB vector search with the HNSW algorithm.
- Support index and search parameters:
  - storage_engine: InnoDB or MyISAM
  - M: mhnsw_max_edges_per_node
  - ef_search: mhnsw_min_limit
  - cache_size: mhnsw_cache_size
- Support the `vectordbbench mariadbhnsw` CLI command.
- Loading branch information
Showing
8 changed files
with
469 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,107 @@ | ||
from typing import Annotated, Optional, Unpack | ||
|
||
import click | ||
import os | ||
from pydantic import SecretStr | ||
|
||
from ....cli.cli import ( | ||
CommonTypedDict, | ||
HNSWFlavor1, | ||
cli, | ||
click_parameter_decorators_from_typed_dict, | ||
run, | ||
) | ||
from vectordb_bench.backend.clients import DB | ||
|
||
|
||
class MariaDBTypedDict(CommonTypedDict):
    """Connection and storage options accepted by the MariaDB CLI commands.

    Every field pairs the Python type of the parsed value with the
    ``click.option`` that declares the corresponding command-line flag.
    """

    # NOTE: click names the parsed parameter after the flag, so the command
    # reads this value as parameters["username"], not "user_name".
    user_name: Annotated[
        str,
        click.option(
            "--username",
            type=str,
            help="Username",
            required=True,
        ),
    ]
    password: Annotated[
        str,
        click.option(
            "--password",
            type=str,
            help="Password",
            required=True,
        ),
    ]

    host: Annotated[
        str,
        click.option(
            "--host",
            type=str,
            help="Db host",
            default="127.0.0.1",
        ),
    ]

    port: Annotated[
        int,
        click.option(
            "--port",
            type=int,
            default=3306,
            help="Db Port",
        ),
    ]

    # click.Choice hands back the selected string, so the value type is str.
    storage_engine: Annotated[
        str,
        click.option(
            "--storage-engine",
            type=click.Choice(["InnoDB", "MyISAM"]),
            help="DB storage engine",
            required=True,
        ),
    ]
|
||
class MariaDBHNSWTypedDict(MariaDBTypedDict):
    """MariaDB CLI options extended with the optional HNSW tuning flags.

    All three flags are optional; when a flag is omitted click passes ``None``
    for it.
    """

    m: Annotated[
        Optional[int],
        click.option(
            "--m",
            type=int,
            help="MariaDB system variable mhnsw_max_edges_per_node",
            required=False,
        ),
    ]

    ef_search: Annotated[
        Optional[int],
        click.option(
            "--ef-search",
            type=int,
            help="MariaDB system variable mhnsw_min_limit",
            required=False,
        ),
    ]

    cache_size: Annotated[
        Optional[int],
        click.option(
            "--cache-size",
            type=int,
            help="MariaDB system variable mhnsw_cache_size",
            required=False,
        ),
    ]
|
||
|
||
@cli.command()
@click_parameter_decorators_from_typed_dict(MariaDBHNSWTypedDict)
def MariaDBHNSW(
    **parameters: Unpack[MariaDBHNSWTypedDict],
):
    """Run VectorDBBench against MariaDB using an HNSW vector index."""
    # The config classes are imported inside the command body, not at module
    # import time.
    from .config import MariaDBConfig, MariaDBHNSWConfig

    # Connection settings. click names each parameter after its flag, which is
    # why the user name is looked up as "username".
    db_config = MariaDBConfig(
        db_label=parameters["db_label"],
        user_name=parameters["username"],
        password=SecretStr(parameters["password"]),
        host=parameters["host"],
        port=parameters["port"],
    )

    # Index/search settings; m, ef_search and cache_size may be None when the
    # corresponding flag was not given.
    db_case_config = MariaDBHNSWConfig(
        M=parameters["m"],
        ef_search=parameters["ef_search"],
        storage_engine=parameters["storage_engine"],
        cache_size=parameters["cache_size"],
    )

    # The full parameter set is forwarded as well, alongside the two configs.
    run(
        db=DB.MariaDB,
        db_config=db_config,
        db_case_config=db_case_config,
        **parameters,
    )
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,71 @@ | ||
from pydantic import SecretStr, BaseModel | ||
from typing import TypedDict | ||
from ..api import DBConfig, DBCaseConfig, MetricType, IndexType | ||
|
||
class MariaDBConfigDict(TypedDict):
    """Keyword arguments passed straight to the MariaDB connection call.

    The keys are used directly as kwargs when connecting, so their names must
    match the MariaDB client API exactly.
    """

    user: str
    password: str
    host: str
    port: int
|
||
|
||
class MariaDBConfig(DBConfig):
    """Connection settings for a MariaDB server."""

    user_name: str = "root"
    password: SecretStr
    host: str = "127.0.0.1"
    port: int = 3306

    def to_dict(self) -> MariaDBConfigDict:
        """Return the settings as connection kwargs.

        The secret password is unwrapped to plain text here, and
        ``user_name`` is exposed under the key ``"user"``.
        """
        return {
            "host": self.host,
            "port": self.port,
            "user": self.user_name,
            "password": self.password.get_secret_value(),
        }
|
||
|
||
class MariaDBIndexConfig(BaseModel):
    """Base config for MariaDB"""

    metric_type: MetricType | None = None

    def parse_metric(self) -> str:
        """Translate ``metric_type`` into the distance name used for MariaDB.

        Returns:
            ``"euclidean"`` for ``MetricType.L2`` and ``"cosine"`` for
            ``MetricType.COSINE``.

        Raises:
            ValueError: for any other metric type, including ``None``.
        """
        if self.metric_type == MetricType.L2:
            return "euclidean"
        if self.metric_type == MetricType.COSINE:
            return "cosine"
        raise ValueError(f"Metric type {self.metric_type} is not supported!")
|
||
class MariaDBHNSWConfig(MariaDBIndexConfig, DBCaseConfig):
    """HNSW case config for MariaDB: index-build and search parameters."""

    # The HNSW tuning values may be None (the CLI passes None when the
    # corresponding flag is omitted).
    M: int | None
    ef_search: int | None
    index: IndexType = IndexType.HNSW
    storage_engine: str = "InnoDB"
    cache_size: int | None

    def index_param(self) -> dict:
        """Return the parameters describing how the index is built.

        Raises:
            ValueError: if ``metric_type`` is not supported by ``parse_metric``.
        """
        return {
            "storage_engine": self.storage_engine,
            "metric_type": self.parse_metric(),
            "index_type": self.index.value,
            "M": self.M,
            "cache_size": self.cache_size,
        }

    def search_param(self) -> dict:
        """Return the parameters applied at search time.

        Raises:
            ValueError: if ``metric_type`` is not supported by ``parse_metric``.
        """
        return {
            "metric_type": self.parse_metric(),
            "ef_search": self.ef_search,
        }
|
||
|
||
# Maps each supported index type to the case-config class that handles it.
_mariadb_case_config = {
    IndexType.HNSW: MariaDBHNSWConfig,
}
|
||
|
Oops, something went wrong.