Skip to content

Commit

Permalink
wip: LASQ - Add LASQ quantization
Browse files Browse the repository at this point in the history
Enable the capability to configure a collection using locally adaptive
scalar quantization (LASQ)

Signed-off-by: Rodrigo Lopez <[email protected]>
  • Loading branch information
rlmanrique committed Nov 14, 2024
1 parent 6af9d69 commit 3bc3e45
Show file tree
Hide file tree
Showing 4 changed files with 100 additions and 4 deletions.
18 changes: 18 additions & 0 deletions integration/test_collection_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from weaviate.collections.classes.config import (
_BQConfig,
_SQConfig,
_LASQConfig,
_CollectionConfig,
_CollectionConfigSimple,
_PQConfig,
Expand Down Expand Up @@ -612,6 +613,23 @@ def test_hnsw_with_sq(collection_factory: CollectionFactory) -> None:
assert isinstance(config.vector_index_config.quantizer, _SQConfig)


def test_hnsw_with_lasq(collection_factory: CollectionFactory) -> None:
    """Create an HNSW collection with a LASQ quantizer and check the fetched config reports it."""
    lasq_quantizer = Configure.VectorIndex.Quantizer.lasq(training_limit=1000000)
    hnsw_config = Configure.VectorIndex.hnsw(
        vector_cache_max_objects=5,
        quantizer=lasq_quantizer,
    )
    collection = collection_factory(vector_index_config=hnsw_config)

    # The version is read from the created collection's connection, which is
    # why the skip check runs after the collection has been created.
    server_version = collection._connection._weaviate_version
    if server_version.is_lower_than(1, 28, 0):
        pytest.skip("LASQ+HNSW is not supported in Weaviate versions lower than 1.28.0")

    fetched = collection.config.get()
    assert fetched.vector_index_type == VectorIndexType.HNSW
    index_config = fetched.vector_index_config
    assert index_config is not None
    assert isinstance(index_config, _VectorIndexConfigHNSW)
    assert isinstance(index_config.quantizer, _LASQConfig)


@pytest.mark.parametrize(
"vector_index_config",
[
Expand Down
12 changes: 12 additions & 0 deletions test/collection/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -1167,6 +1167,18 @@ def test_vector_config_flat_pq() -> None:
assert vi_dict["pq"]["segments"] == 789


def test_vector_config_hnsw_lasq() -> None:
    """Check that an HNSW create config carrying a LASQ quantizer serializes the LASQ settings."""
    vector_index = Configure.VectorIndex.hnsw(
        ef_construction=128,
        quantizer=Configure.VectorIndex.Quantizer.lasq(training_limit=5012),
    )

    vi_dict = vector_index._to_dict()

    assert vi_dict["efConstruction"] == 128
    # `_LASQConfigCreate.quantizer_name()` returns "lasq", so the LASQ settings
    # are expected under that key; "sq" is the name of the plain SQ quantizer.
    assert vi_dict["lasq"]["trainingLimit"] == 5012


TEST_CONFIG_WITH_NAMED_VECTORIZER_PARAMETERS = [
(
[Configure.NamedVectors.text2vec_contextionary(name="test", source_properties=["prop"])],
Expand Down
64 changes: 62 additions & 2 deletions weaviate/collections/classes/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,6 +325,15 @@ def quantizer_name() -> str:
return "sq"


class _LASQConfigCreate(_QuantizerConfigCreate):
    """Create-time settings for the locally adaptive scalar quantizer (LASQ)."""

    # Both settings are optional; field names use the camelCase spelling
    # that the rest of the create models in this module use.
    cache: Optional[bool]
    trainingLimit: Optional[int]

    @staticmethod
    def quantizer_name() -> str:
        """Return the name that identifies this quantizer: `lasq`."""
        return "lasq"


class _PQConfigUpdate(_QuantizerConfigUpdate):
bitCompression: Optional[bool] = Field(default=None)
centroids: Optional[int]
Expand Down Expand Up @@ -357,6 +366,15 @@ def quantizer_name() -> str:
return "sq"


class _LASQConfigUpdate(_QuantizerConfigUpdate):
    """Update-time settings for the locally adaptive scalar quantizer (LASQ)."""

    # `enabled` toggles the quantizer; `trainingLimit` mirrors the
    # create-time field of the same name.
    enabled: Optional[bool]
    trainingLimit: Optional[int]

    @staticmethod
    def quantizer_name() -> str:
        """Return the name that identifies this quantizer: `lasq`."""
        return "lasq"


class _ShardingConfigCreate(_ConfigCreateModel):
virtualPerPhysical: Optional[int]
desiredCount: Optional[int]
Expand Down Expand Up @@ -1499,13 +1517,20 @@ class _SQConfig(_ConfigBase):
training_limit: int


@dataclass
class _LASQConfig(_ConfigBase):
    """Read-side representation of a collection's LASQ quantizer settings."""

    # Optional cache flag as reported in the collection config.
    cache: Optional[bool]
    # NOTE(review): the config parser fills this with `dict.get("trainingLimit")`,
    # so it may be `None` if the key is absent - confirm whether `Optional[int]`
    # would be the more accurate annotation.
    training_limit: int


# Underscore-free aliases for the private quantizer config classes, used in
# the `quantizer` annotation of `_VectorIndexConfig` below.
BQConfig = _BQConfig
SQConfig = _SQConfig
LASQConfig = _LASQConfig


@dataclass
class _VectorIndexConfig(_ConfigBase):
quantizer: Optional[Union[PQConfig, BQConfig, SQConfig]]
quantizer: Optional[Union[PQConfig, BQConfig, SQConfig, LASQConfig]]

def to_dict(self) -> Dict[str, Any]:
out = super().to_dict()
Expand All @@ -1515,6 +1540,8 @@ def to_dict(self) -> Dict[str, Any]:
out["bq"] = {**out.pop("quantizer"), "enabled": True}
elif isinstance(self.quantizer, _SQConfig):
out["sq"] = {**out.pop("quantizer"), "enabled": True}
elif isinstance(self.quantizer, _LASQConfig):
out["lasq"] = {**out.pop("quantizer"), "enabled": True}
return out


Expand Down Expand Up @@ -2033,6 +2060,23 @@ def sq(
trainingLimit=training_limit,
)

@staticmethod
def lasq(
    cache: Optional[bool] = None,
    training_limit: Optional[int] = None,
) -> _LASQConfigCreate:
    """Create a `_LASQConfigCreate` object to be used when defining the locally adaptive scalar quantization (LASQ) configuration of Weaviate.

    Use this method when defining the `quantizer` argument in the `vector_index` configuration.

    Arguments:
        `cache`
            Optional flag passed through as the `cache` setting. Defaults to `None`.
        `training_limit`
            Optional integer passed through as the `trainingLimit` setting. Defaults to `None`.

    See [the docs](https://weaviate.io/developers/weaviate/concepts/vector-index) for a more detailed view!
    """  # noqa: D417
    # NOTE(review): earlier wording here said the arguments have no effect for
    # HNSW and linked to the binary-quantization section; both appear to have
    # been copied from the BQ helper - confirm the LASQ semantics server-side.
    return _LASQConfigCreate(
        cache=cache,
        trainingLimit=training_limit,
    )


class _VectorIndex:
Quantizer = _VectorIndexQuantizer
Expand Down Expand Up @@ -2319,6 +2363,20 @@ def sq(
enabled=enabled, rescoreLimit=rescore_limit, trainingLimit=training_limit
)

@staticmethod
def lasq(
    training_limit: Optional[int] = None,
    enabled: bool = True,
) -> _LASQConfigUpdate:
    """Build the `_LASQConfigUpdate` object used to update the locally adaptive SQ (LASQ) configuration of Weaviate.

    Pass the result as the `quantizer` argument of the `vector_index` configuration in `collection.update()`.

    Arguments:
        See [the docs](https://weaviate.io/developers/weaviate/concepts/vector-index#hnsw-with-compression) for a more detailed view!
    """  # noqa: D417 (missing argument descriptions in the docstring)
    lasq_update = _LASQConfigUpdate(
        enabled=enabled,
        trainingLimit=training_limit,
    )
    return lasq_update


class _VectorIndexUpdate:
Quantizer = _VectorIndexQuantizerUpdate
Expand All @@ -2332,7 +2390,9 @@ def hnsw(
flat_search_cutoff: Optional[int] = None,
filter_strategy: Optional[VectorFilterStrategy] = None,
vector_cache_max_objects: Optional[int] = None,
quantizer: Optional[Union[_PQConfigUpdate, _BQConfigUpdate, _SQConfigUpdate]] = None,
quantizer: Optional[
Union[_PQConfigUpdate, _BQConfigUpdate, _SQConfigUpdate, _LASQConfigUpdate]
] = None,
) -> _VectorIndexConfigHNSWUpdate:
"""Create an `_VectorIndexConfigHNSWUpdate` object to update the configuration of the HNSW vector index.
Expand Down
10 changes: 8 additions & 2 deletions weaviate/collections/classes/config_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from weaviate.collections.classes.config import (
_BQConfig,
_SQConfig,
_LASQConfig,
_CollectionConfig,
_CollectionConfigSimple,
_NamedVectorConfig,
Expand Down Expand Up @@ -117,8 +118,8 @@ def __get_vector_index_type(schema: Dict[str, Any]) -> Optional[VectorIndexType]

def __get_quantizer_config(
config: Dict[str, Any]
) -> Optional[Union[_PQConfig, _BQConfig, _SQConfig]]:
quantizer: Optional[Union[_PQConfig, _BQConfig, _SQConfig]] = None
) -> Optional[Union[_PQConfig, _BQConfig, _SQConfig, _LASQConfig]]:
quantizer: Optional[Union[_PQConfig, _BQConfig, _SQConfig, _LASQConfig]] = None
if "bq" in config and config["bq"]["enabled"]:
# values are not present for bq+hnsw
quantizer = _BQConfig(
Expand All @@ -145,6 +146,11 @@ def __get_quantizer_config(
),
),
)
elif "lasq" in config and config["lasq"].get("enabled"):
quantizer = _LASQConfig(
cache=config["lasq"].get("cache"),
training_limit=config["lasq"].get("trainingLimit"),
)
return quantizer


Expand Down

0 comments on commit 3bc3e45

Please sign in to comment.