Skip to content

Commit

Permalink
Support MariaDB database
Browse files Browse the repository at this point in the history
MariaDB has published vector support in Preview release.
https://mariadb.com/kb/en/vector-overview/

MariaDB's vector features are currently only available in preview and
development versions. The VectorDBBench code for MariaDB is fully
functional and tested against the MariaDB commit
https://github.com/MariaDB/server/tree/1f044176ed, a more recent
development version with interface modifications and performance
enhancements. However, this client in VectorDBBench will become
practically useful with the next MariaDB major release for general
availability.

- Support MariaDB vector search with HNSW algorithm.
- Support index and search parameters:
   - storage_engine: InnoDB or MyISAM
   - M: mhnsw_max_edges_per_node
   - ef_search: mhnsw_min_limit
   - cache_size: mhnsw_cache_size
- Support CLI of `vectordbbench mariadbhnsw`.
  • Loading branch information
HugoWenTD committed Oct 18, 2024
1 parent b364fe3 commit 9151174
Show file tree
Hide file tree
Showing 8 changed files with 469 additions and 0 deletions.
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ all = [
"psycopg-binary",
"opensearch-dsl==2.1.0",
"opensearch-py==2.6.0",
"mariadb",
]

qdrant = [ "qdrant-client" ]
Expand All @@ -78,6 +79,7 @@ memorydb = [ "memorydb" ]
chromadb = [ "chromadb" ]
awsopensearch = [ "awsopensearch" ]
zilliz_cloud = []
mariadb = [ "mariadb" ]

[project.urls]
"repository" = "https://github.com/zilliztech/VectorDBBench"
Expand Down
13 changes: 13 additions & 0 deletions vectordb_bench/backend/clients/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ class DB(Enum):
MemoryDB = "MemoryDB"
Chroma = "Chroma"
AWSOpenSearch = "OpenSearch"
MariaDB = "MariaDB"
Test = "test"


Expand Down Expand Up @@ -93,6 +94,10 @@ def init_cls(self) -> Type[VectorDB]:
from .aws_opensearch.aws_opensearch import AWSOpenSearch
return AWSOpenSearch

if self == DB.MariaDB:
from .mariadb.mariadb import MariaDB
return MariaDB

@property
def config_cls(self) -> Type[DBConfig]:
"""Import while in use"""
Expand Down Expand Up @@ -148,6 +153,10 @@ def config_cls(self) -> Type[DBConfig]:
from .aws_opensearch.config import AWSOpenSearchConfig
return AWSOpenSearchConfig

if self == DB.MariaDB:
from .mariadb.config import MariaDBConfig
return MariaDBConfig

def case_config_cls(self, index_type: IndexType | None = None) -> Type[DBCaseConfig]:
if self == DB.Milvus:
from .milvus.config import _milvus_case_config
Expand Down Expand Up @@ -185,6 +194,10 @@ def case_config_cls(self, index_type: IndexType | None = None) -> Type[DBCaseCon
from .pgvectorscale.config import _pgvectorscale_case_config
return _pgvectorscale_case_config.get(index_type)

if self == DB.MariaDB:
from .mariadb.config import _mariadb_case_config
return _mariadb_case_config.get(index_type)

# DB.Pinecone, DB.Chroma, DB.Redis
return EmptyDBCaseConfig

Expand Down
107 changes: 107 additions & 0 deletions vectordb_bench/backend/clients/mariadb/cli.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
from typing import Annotated, Optional, Unpack

import click
import os
from pydantic import SecretStr

from ....cli.cli import (
CommonTypedDict,
HNSWFlavor1,
cli,
click_parameter_decorators_from_typed_dict,
run,
)
from vectordb_bench.backend.clients import DB


class MariaDBTypedDict(CommonTypedDict):
user_name: Annotated[
str, click.option("--username",
type=str,
help="Username",
required=True,
),
]
password: Annotated[
str, click.option("--password",
type=str,
help="Password",
required=True,
),
]

host: Annotated[
str, click.option("--host",
type=str,
help="Db host",
default="127.0.0.1",
),
]

port: Annotated[
int, click.option("--port",
type=int,
default=3306,
help="Db Port",
),
]

storage_engine: Annotated[
int, click.option("--storage-engine",
type=click.Choice(["InnoDB", "MyISAM"]),
help="DB storage engine",
required=True,
),
]

class MariaDBHNSWTypedDict(MariaDBTypedDict):
...
m: Annotated[
Optional[int], click.option("--m",
type=int,
help="MariaDB system variable mhnsw_max_edges_per_node",
required=False,
),
]

ef_search: Annotated[
Optional[int], click.option("--ef-search",
type=int,
help="MariaDB system variable mhnsw_min_limit",
required=False,
),
]

cache_size: Annotated[
Optional[int], click.option("--cache-size",
type=int,
help="MariaDB system variable mhnsw_cache_size",
required=False,
),
]


@cli.command()
@click_parameter_decorators_from_typed_dict(MariaDBHNSWTypedDict)
def MariaDBHNSW(
**parameters: Unpack[MariaDBHNSWTypedDict],
):
from .config import MariaDBConfig, MariaDBHNSWConfig

run(
db=DB.MariaDB,
db_config=MariaDBConfig(
db_label=parameters["db_label"],
user_name=parameters["username"],
password=SecretStr(parameters["password"]),
host=parameters["host"],
port=parameters["port"],
),
db_case_config=MariaDBHNSWConfig(
M=parameters["m"],
ef_search=parameters["ef_search"],
storage_engine=parameters["storage_engine"],
cache_size=parameters["cache_size"],
),
**parameters,
)
71 changes: 71 additions & 0 deletions vectordb_bench/backend/clients/mariadb/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
from pydantic import SecretStr, BaseModel
from typing import TypedDict
from ..api import DBConfig, DBCaseConfig, MetricType, IndexType

class MariaDBConfigDict(TypedDict):
"""These keys will be directly used as kwargs in mariadb connection string,
so the names must match exactly mariadb API"""

user: str
password: str
host: str
port: int


class MariaDBConfig(DBConfig):
user_name: str = "root"
password: SecretStr
host: str = "127.0.0.1"
port: int = 3306

def to_dict(self) -> MariaDBConfigDict:
pwd_str = self.password.get_secret_value()
return {
"host": self.host,
"port": self.port,
"user": self.user_name,
"password": pwd_str,
}


class MariaDBIndexConfig(BaseModel):
"""Base config for MariaDB"""

metric_type: MetricType | None = None

def parse_metric(self) -> str:
if self.metric_type == MetricType.L2:
return "euclidean"
elif self.metric_type == MetricType.COSINE:
return "cosine"
else:
raise ValueError(f"Metric type {self.metric_type} is not supported!")

class MariaDBHNSWConfig(MariaDBIndexConfig, DBCaseConfig):
M: int | None
ef_search: int | None
index: IndexType = IndexType.HNSW
storage_engine: str = "InnoDB"
cache_size: int | None

def index_param(self) -> dict:
return {
"storage_engine": self.storage_engine,
"metric_type": self.parse_metric(),
"index_type": self.index.value,
"M": self.M,
"cache_size": self.cache_size,
}

def search_param(self) -> dict:
return {
"metric_type": self.parse_metric(),
"ef_search": self.ef_search,
}


_mariadb_case_config = {
IndexType.HNSW: MariaDBHNSWConfig,
}


Loading

0 comments on commit 9151174

Please sign in to comment.