From e98b0aa19b3f4fcafaadf18aa9f0d09d60b4dfda Mon Sep 17 00:00:00 2001 From: Tibor Reiss Date: Wed, 25 Sep 2024 21:40:16 +0200 Subject: [PATCH] Refactor --- .../collections/batch/grpc_batch_objects.py | 43 ++++------- weaviate/collections/data/data.py | 75 +++++++++++++------ 2 files changed, 67 insertions(+), 51 deletions(-) diff --git a/weaviate/collections/batch/grpc_batch_objects.py b/weaviate/collections/batch/grpc_batch_objects.py index 2716c783d..92b8dba4d 100644 --- a/weaviate/collections/batch/grpc_batch_objects.py +++ b/weaviate/collections/batch/grpc_batch_objects.py @@ -2,7 +2,7 @@ import struct import time import uuid as uuid_package -from typing import Any, Dict, List, Optional, Union, cast +from typing import Any, Dict, List, Optional, Tuple, Union, cast from grpc.aio import AioRpcError # type: ignore from google.protobuf.struct_pb2 import Struct @@ -24,17 +24,7 @@ WeaviateInvalidInputError, ) from weaviate.proto.v1 import batch_pb2, base_pb2 -from weaviate.util import _datetime_to_string, _get_vector_v4 - - -def _pack_named_vectors(vectors: Dict[str, List[float]]) -> List[base_pb2.Vectors]: - return [ - base_pb2.Vectors( - name=name, - vector_bytes=struct.pack("{}f".format(len(vector)), *vector), - ) - for name, vector in vectors.items() - ] +from weaviate.util import _datetime_to_string class _BatchGRPC(_BaseGRPC): @@ -47,11 +37,10 @@ class _BatchGRPC(_BaseGRPC): def __init__(self, connection: ConnectionV4, consistency_level: Optional[ConsistencyLevel]): super().__init__(connection, consistency_level) - def __grpc_objects(self, objects: List[_BatchObject]) -> List[batch_pb2.BatchObject]: - def pack_vector(vector: Any) -> bytes: - vector_list = _get_vector_v4(vector) - return struct.pack("{}f".format(len(vector_list)), *vector_list) - + def __grpc_objects( + self, + objects: List[Tuple[_BatchObject, Optional[bytes], Optional[List[base_pb2.Vectors]]]], + ) -> List[batch_pb2.BatchObject]: return [ batch_pb2.BatchObject( collection=obj.collection, @@ -65,22 +54,16 @@ def pack_vector(vector: Any) -> bytes: else None ), tenant=obj.tenant, - vector_bytes=( - pack_vector(obj.vector) - if obj.vector is not None and not isinstance(obj.vector, dict) - else None - ), - vectors=( - _pack_named_vectors(obj.vector) - if obj.vector is not None and isinstance(obj.vector, dict) - else None - ), + vector_bytes=vector_bytes, + vectors=vectors, ) - for obj in objects + for obj, vector_bytes, vectors in objects ] async def objects( - self, objects: List[_BatchObject], timeout: Union[int, float] + self, + objects: List[Tuple[_BatchObject, Optional[bytes], Optional[List[base_pb2.Vectors]]]], + timeout: Union[int, float] ) -> BatchObjectReturn: """Insert multiple objects into Weaviate through the gRPC API. @@ -114,7 +97,7 @@ async def objects( return_errors: Dict[int, ErrorObject] = {} for idx, weav_obj in enumerate(weaviate_objs): - obj = objects[idx] + obj = objects[idx][0] if idx in errors: error = ErrorObject(errors[idx], obj, original_uuid=obj.uuid) return_errors[obj.index] = error diff --git a/weaviate/collections/data/data.py b/weaviate/collections/data/data.py index b86d0ac94..b76a794b1 100644 --- a/weaviate/collections/data/data.py +++ b/weaviate/collections/data/data.py @@ -1,5 +1,6 @@ import asyncio import datetime +import struct import uuid as uuid_package from typing import ( Dict, @@ -47,6 +48,7 @@ from weaviate.connect import ConnectionV4 from weaviate.connect.v4 import _ExpectedStatusCodes from weaviate.logger import logger +from weaviate.proto.v1 import base_pb2 from weaviate.types import BEACON, UUID, VECTORS from weaviate.util import _datetime_to_string, _get_vector_v4 from weaviate.validator import _validate_input, _ValidateArgument @@ -57,6 +59,21 @@ from weaviate.exceptions import WeaviateInvalidInputError +def _pack_named_vectors(vectors: Dict[str, List[float]]) -> List[base_pb2.Vectors]: + return [ + base_pb2.Vectors( + name=name, + vector_bytes=struct.pack("{}f".format(len(vector)), *vector), + ) + for name, vector in vectors.items() + ] + + +def _pack_vector(vector: Any) -> bytes: + vector_list = _get_vector_v4(vector) + return struct.pack("{}f".format(len(vector_list)), *vector_list) + + class _DataBase: def __init__( self, @@ -281,6 +298,42 @@ def with_data_model(self, data_model: Type[TProperties]) -> "_DataCollectionAsyn data_model, ) + def __validate_vector( + self, + idx: int, + obj: Union[Properties, DataObject[Properties, Optional[ReferenceInputs]]] + ) -> Tuple[_BatchObject, Optional[bytes], Optional[List[base_pb2.Vectors]]]: + if isinstance(obj, DataObject): + vector_bytes = ( + _pack_vector(obj.vector) + if obj.vector is not None and not isinstance(obj.vector, dict) + else None + ) + vectors = ( + _pack_named_vectors(obj.vector) + if obj.vector is not None and isinstance(obj.vector, dict) + else None + ) + return _BatchObject( + collection=self.name, + vector=obj.vector, + uuid=str(obj.uuid if obj.uuid is not None else uuid_package.uuid4()), + properties=cast(dict, obj.properties), + tenant=self._tenant, + references=obj.references, + index=idx, + ), vector_bytes, vectors + + return _BatchObject( + collection=self.name, + vector=None, + uuid=str(uuid_package.uuid4()), + properties=cast(dict, obj), + tenant=self._tenant, + references=None, + index=idx, + ), None, None + def __parse_vector(self, obj: Dict[str, Any], vector: VECTORS) -> Dict[str, Any]: if isinstance(vector, dict): obj["vectors"] = {key: _get_vector_v4(val) for key, val in vector.items()} @@ -360,27 +413,7 @@ async def insert_many( If every object in the batch fails to be inserted. The exception message contains details about the failure. """ objs = [ - ( - _BatchObject( - collection=self.name, - vector=obj.vector, - uuid=str(obj.uuid if obj.uuid is not None else uuid_package.uuid4()), - properties=cast(dict, obj.properties), - tenant=self._tenant, - references=obj.references, - index=idx, - ) - if isinstance(obj, DataObject) - else _BatchObject( - collection=self.name, - vector=None, - uuid=str(uuid_package.uuid4()), - properties=cast(dict, obj), - tenant=self._tenant, - references=None, - index=idx, - ) - ) + self.__validate_vector(idx, obj) for idx, obj in enumerate(objects) ] res = await self._batch_grpc.objects(objs, timeout=self._connection.timeout_config.insert)