Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP: LASQ - Add LASQ quantization #1409

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions integration/test_collection_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from weaviate.collections.classes.config import (
_BQConfig,
_SQConfig,
_LASQConfig,
_CollectionConfig,
_CollectionConfigSimple,
_PQConfig,
Expand Down Expand Up @@ -612,6 +613,23 @@ def test_hnsw_with_sq(collection_factory: CollectionFactory) -> None:
assert isinstance(config.vector_index_config.quantizer, _SQConfig)


def test_hnsw_with_lasq(collection_factory: CollectionFactory) -> None:
    """Check that an HNSW index created with the LASQ quantizer reports a `_LASQConfig`.

    The server version is probed on a throwaway collection *before* the LASQ
    collection is requested, so that servers older than 1.28.0 skip the test
    instead of being sent a quantizer configuration they do not support.
    """
    # NOTE(review): assumes the factory's first positional argument is a collection
    # name and that all other arguments are optional — confirm against the fixture.
    dummy = collection_factory("dummy")
    if dummy._connection._weaviate_version.is_lower_than(1, 28, 0):
        pytest.skip("LASQ+HNSW is not supported in Weaviate versions lower than 1.28.0")

    collection = collection_factory(
        vector_index_config=Configure.VectorIndex.hnsw(
            vector_cache_max_objects=5,
            quantizer=Configure.VectorIndex.Quantizer.lasq(training_limit=1000000),
        ),
    )

    config = collection.config.get()
    assert config.vector_index_type == VectorIndexType.HNSW
    assert config.vector_index_config is not None
    assert isinstance(config.vector_index_config, _VectorIndexConfigHNSW)
    assert isinstance(config.vector_index_config.quantizer, _LASQConfig)


@pytest.mark.parametrize(
"vector_index_config",
[
Expand Down
12 changes: 12 additions & 0 deletions test/collection/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -1167,6 +1167,18 @@ def test_vector_config_flat_pq() -> None:
assert vi_dict["pq"]["segments"] == 789


def test_vector_config_hnsw_lasq() -> None:
    """An HNSW create-config carrying a LASQ quantizer serialises both sets of options."""
    lasq_quantizer = Configure.VectorIndex.Quantizer.lasq(training_limit=5012)
    serialised = Configure.VectorIndex.hnsw(
        ef_construction=128,
        quantizer=lasq_quantizer,
    )._to_dict()

    # The index option stays at the top level; the quantizer options are nested under "lasq".
    assert serialised["efConstruction"] == 128
    assert serialised["lasq"]["trainingLimit"] == 5012


TEST_CONFIG_WITH_NAMED_VECTORIZER_PARAMETERS = [
(
[Configure.NamedVectors.text2vec_contextionary(name="test", source_properties=["prop"])],
Expand Down
64 changes: 62 additions & 2 deletions weaviate/collections/classes/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,6 +325,15 @@ def quantizer_name() -> str:
return "sq"


class _LASQConfigCreate(_QuantizerConfigCreate):
    """Options supplied when a collection is created with the LASQ quantizer."""

    # Both options may be left as None by the caller.
    cache: Optional[bool]
    trainingLimit: Optional[int]

    @staticmethod
    def quantizer_name() -> str:
        """Return the identifier of this quantizer, `"lasq"`."""
        return "lasq"


class _PQConfigUpdate(_QuantizerConfigUpdate):
bitCompression: Optional[bool] = Field(default=None)
centroids: Optional[int]
Expand Down Expand Up @@ -357,6 +366,15 @@ def quantizer_name() -> str:
return "sq"


class _LASQConfigUpdate(_QuantizerConfigUpdate):
    """Options supplied when the LASQ quantizer of an existing collection is updated."""

    # Both options may be left as None by the caller.
    enabled: Optional[bool]
    trainingLimit: Optional[int]

    @staticmethod
    def quantizer_name() -> str:
        """Return the identifier of this quantizer, `"lasq"`."""
        return "lasq"


class _ShardingConfigCreate(_ConfigCreateModel):
virtualPerPhysical: Optional[int]
desiredCount: Optional[int]
Expand Down Expand Up @@ -1499,13 +1517,20 @@ class _SQConfig(_ConfigBase):
training_limit: int


@dataclass
class _LASQConfig(_ConfigBase):
    """LASQ quantizer settings as read back from a collection's configuration."""

    # Both values are optional: the code that builds this object from the server
    # response looks the keys up with `.get(...)`, so either may be None.
    cache: Optional[bool]
    training_limit: Optional[int]


# Underscore-free aliases for the quantizer config dataclasses.
BQConfig = _BQConfig
SQConfig = _SQConfig
LASQConfig = _LASQConfig


@dataclass
class _VectorIndexConfig(_ConfigBase):
quantizer: Optional[Union[PQConfig, BQConfig, SQConfig]]
quantizer: Optional[Union[PQConfig, BQConfig, SQConfig, LASQConfig]]

def to_dict(self) -> Dict[str, Any]:
out = super().to_dict()
Expand All @@ -1515,6 +1540,8 @@ def to_dict(self) -> Dict[str, Any]:
out["bq"] = {**out.pop("quantizer"), "enabled": True}
elif isinstance(self.quantizer, _SQConfig):
out["sq"] = {**out.pop("quantizer"), "enabled": True}
elif isinstance(self.quantizer, _LASQConfig):
out["lasq"] = {**out.pop("quantizer"), "enabled": True}
return out


Expand Down Expand Up @@ -2033,6 +2060,23 @@ def sq(
trainingLimit=training_limit,
)

@staticmethod
def lasq(
    cache: Optional[bool] = None,
    training_limit: Optional[int] = None,
) -> _LASQConfigCreate:
    """Create a `_LASQConfigCreate` object to be used when defining the locally adaptive SQ (LASQ) configuration of Weaviate.

    Use this method when defining the `quantizer` argument in the `vector_index` configuration.

    Arguments:
        cache: Passed through unchanged as the `cache` option of the LASQ configuration.
            Defaults to `None`.
        training_limit: Passed through as the `trainingLimit` option of the LASQ configuration.
            Defaults to `None`.

    See [the docs](https://weaviate.io/developers/weaviate/concepts/vector-index) for a more detailed view!
    """  # noqa: D417
    return _LASQConfigCreate(
        cache=cache,
        trainingLimit=training_limit,
    )


class _VectorIndex:
Quantizer = _VectorIndexQuantizer
Expand Down Expand Up @@ -2319,6 +2363,20 @@ def sq(
enabled=enabled, rescoreLimit=rescore_limit, trainingLimit=training_limit
)

@staticmethod
def lasq(
    training_limit: Optional[int] = None,
    enabled: bool = True,
) -> _LASQConfigUpdate:
    """Build the `_LASQConfigUpdate` object describing a change to the locally adaptive SQ (LASQ) configuration of Weaviate.

    Pass the result as the `quantizer` argument of the `vector_index` configuration in `collection.update()`.

    Arguments:
        See [the docs](https://weaviate.io/developers/weaviate/concepts/vector-index#hnsw-with-compression) for a more detailed view!
    """  # noqa: D417 (missing argument descriptions in the docstring)
    lasq_update = _LASQConfigUpdate(trainingLimit=training_limit, enabled=enabled)
    return lasq_update


class _VectorIndexUpdate:
Quantizer = _VectorIndexQuantizerUpdate
Expand All @@ -2332,7 +2390,9 @@ def hnsw(
flat_search_cutoff: Optional[int] = None,
filter_strategy: Optional[VectorFilterStrategy] = None,
vector_cache_max_objects: Optional[int] = None,
quantizer: Optional[Union[_PQConfigUpdate, _BQConfigUpdate, _SQConfigUpdate]] = None,
quantizer: Optional[
Union[_PQConfigUpdate, _BQConfigUpdate, _SQConfigUpdate, _LASQConfigUpdate]
] = None,
) -> _VectorIndexConfigHNSWUpdate:
"""Create an `_VectorIndexConfigHNSWUpdate` object to update the configuration of the HNSW vector index.

Expand Down
10 changes: 8 additions & 2 deletions weaviate/collections/classes/config_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from weaviate.collections.classes.config import (
_BQConfig,
_SQConfig,
_LASQConfig,
_CollectionConfig,
_CollectionConfigSimple,
_NamedVectorConfig,
Expand Down Expand Up @@ -117,8 +118,8 @@ def __get_vector_index_type(schema: Dict[str, Any]) -> Optional[VectorIndexType]

def __get_quantizer_config(
config: Dict[str, Any]
) -> Optional[Union[_PQConfig, _BQConfig, _SQConfig]]:
quantizer: Optional[Union[_PQConfig, _BQConfig, _SQConfig]] = None
) -> Optional[Union[_PQConfig, _BQConfig, _SQConfig, _LASQConfig]]:
quantizer: Optional[Union[_PQConfig, _BQConfig, _SQConfig, _LASQConfig]] = None
if "bq" in config and config["bq"]["enabled"]:
# values are not present for bq+hnsw
quantizer = _BQConfig(
Expand All @@ -145,6 +146,11 @@ def __get_quantizer_config(
),
),
)
elif "lasq" in config and config["lasq"].get("enabled"):
quantizer = _LASQConfig(
cache=config["lasq"].get("cache"),
training_limit=config["lasq"].get("trainingLimit"),
)
return quantizer


Expand Down