Skip to content

Commit

Permalink
Refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
tibor-reiss committed Sep 25, 2024
1 parent 9f21009 commit e98b0aa
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 51 deletions.
43 changes: 13 additions & 30 deletions weaviate/collections/batch/grpc_batch_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import struct
import time
import uuid as uuid_package
from typing import Any, Dict, List, Optional, Union, cast
from typing import Any, Dict, List, Optional, Tuple, Union, cast

from grpc.aio import AioRpcError # type: ignore
from google.protobuf.struct_pb2 import Struct
Expand All @@ -24,17 +24,7 @@
WeaviateInvalidInputError,
)
from weaviate.proto.v1 import batch_pb2, base_pb2
from weaviate.util import _datetime_to_string, _get_vector_v4


def _pack_named_vectors(vectors: Dict[str, List[float]]) -> List[base_pb2.Vectors]:
return [
base_pb2.Vectors(
name=name,
vector_bytes=struct.pack("{}f".format(len(vector)), *vector),
)
for name, vector in vectors.items()
]
from weaviate.util import _datetime_to_string


class _BatchGRPC(_BaseGRPC):
Expand All @@ -47,11 +37,10 @@ class _BatchGRPC(_BaseGRPC):
def __init__(self, connection: ConnectionV4, consistency_level: Optional[ConsistencyLevel]):
super().__init__(connection, consistency_level)

def __grpc_objects(self, objects: List[_BatchObject]) -> List[batch_pb2.BatchObject]:
def pack_vector(vector: Any) -> bytes:
vector_list = _get_vector_v4(vector)
return struct.pack("{}f".format(len(vector_list)), *vector_list)

def __grpc_objects(
self,
objects: List[Tuple[_BatchObject, Optional[bytes], Optional[List[base_pb2.Vectors]]]],
) -> List[batch_pb2.BatchObject]:
return [
batch_pb2.BatchObject(
collection=obj.collection,
Expand All @@ -65,22 +54,16 @@ def pack_vector(vector: Any) -> bytes:
else None
),
tenant=obj.tenant,
vector_bytes=(
pack_vector(obj.vector)
if obj.vector is not None and not isinstance(obj.vector, dict)
else None
),
vectors=(
_pack_named_vectors(obj.vector)
if obj.vector is not None and isinstance(obj.vector, dict)
else None
),
vector_bytes=vector_bytes,
vectors=vectors,
)
for obj in objects
for obj, vector_bytes, vectors in objects
]

async def objects(
self, objects: List[_BatchObject], timeout: Union[int, float]
self,
objects: List[Tuple[_BatchObject, Optional[bytes], Optional[List[base_pb2.Vectors]]]],
timeout: Union[int, float]
) -> BatchObjectReturn:
"""Insert multiple objects into Weaviate through the gRPC API.
Expand Down Expand Up @@ -114,7 +97,7 @@ async def objects(
return_errors: Dict[int, ErrorObject] = {}

for idx, weav_obj in enumerate(weaviate_objs):
obj = objects[idx]
obj = objects[idx][0]
if idx in errors:
error = ErrorObject(errors[idx], obj, original_uuid=obj.uuid)
return_errors[obj.index] = error
Expand Down
75 changes: 54 additions & 21 deletions weaviate/collections/data/data.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import asyncio
import datetime
import struct
import uuid as uuid_package
from typing import (
Dict,
Expand Down Expand Up @@ -47,6 +48,7 @@
from weaviate.connect import ConnectionV4
from weaviate.connect.v4 import _ExpectedStatusCodes
from weaviate.logger import logger
from weaviate.proto.v1 import base_pb2
from weaviate.types import BEACON, UUID, VECTORS
from weaviate.util import _datetime_to_string, _get_vector_v4
from weaviate.validator import _validate_input, _ValidateArgument
Expand All @@ -57,6 +59,21 @@
from weaviate.exceptions import WeaviateInvalidInputError


def _pack_named_vectors(vectors: Dict[str, List[float]]) -> List[base_pb2.Vectors]:
    """Convert a mapping of named vectors into gRPC ``Vectors`` messages.

    Each vector is serialized into platform-native float32 bytes with
    ``struct.pack`` and wrapped in a ``base_pb2.Vectors`` message carrying
    its name.
    """
    packed: List[base_pb2.Vectors] = []
    for vector_name, values in vectors.items():
        raw = struct.pack("{}f".format(len(values)), *values)
        packed.append(base_pb2.Vectors(name=vector_name, vector_bytes=raw))
    return packed


def _pack_vector(vector: Any) -> bytes:
    """Serialize a single (unnamed) vector into packed float32 bytes.

    The input is first normalized to a plain list of floats by
    ``_get_vector_v4`` before packing with ``struct``.
    """
    values = _get_vector_v4(vector)
    return struct.pack("{}f".format(len(values)), *values)


class _DataBase:
def __init__(
self,
Expand Down Expand Up @@ -281,6 +298,42 @@ def with_data_model(self, data_model: Type[TProperties]) -> "_DataCollectionAsyn
data_model,
)

def __validate_vector(
self,
idx: int,
obj: Union[Properties, DataObject[Properties, Optional[ReferenceInputs]]]
) -> Tuple[_BatchObject, Optional[bytes], Optional[List[base_pb2.Vectors]]]:
if isinstance(obj, DataObject):
vector_bytes = (
_pack_vector(obj.vector)
if obj.vector is not None and not isinstance(obj.vector, dict)
else None
)
vectors = (
_pack_named_vectors(obj.vector)
if obj.vector is not None and isinstance(obj.vector, dict)
else None
)
return _BatchObject(
collection=self.name,
vector=obj.vector,
uuid=str(obj.uuid if obj.uuid is not None else uuid_package.uuid4()),
properties=cast(dict, obj.properties),
tenant=self._tenant,
references=obj.references,
index=idx,
), vector_bytes, vectors

return _BatchObject(
collection=self.name,
vector=None,
uuid=str(uuid_package.uuid4()),
properties=cast(dict, obj),
tenant=self._tenant,
references=None,
index=idx,
), None, None

def __parse_vector(self, obj: Dict[str, Any], vector: VECTORS) -> Dict[str, Any]:
if isinstance(vector, dict):
obj["vectors"] = {key: _get_vector_v4(val) for key, val in vector.items()}
Expand Down Expand Up @@ -360,27 +413,7 @@ async def insert_many(
If every object in the batch fails to be inserted. The exception message contains details about the failure.
"""
objs = [
(
_BatchObject(
collection=self.name,
vector=obj.vector,
uuid=str(obj.uuid if obj.uuid is not None else uuid_package.uuid4()),
properties=cast(dict, obj.properties),
tenant=self._tenant,
references=obj.references,
index=idx,
)
if isinstance(obj, DataObject)
else _BatchObject(
collection=self.name,
vector=None,
uuid=str(uuid_package.uuid4()),
properties=cast(dict, obj),
tenant=self._tenant,
references=None,
index=idx,
)
)
self.__validate_vector(idx, obj)
for idx, obj in enumerate(objects)
]
res = await self._batch_grpc.objects(objs, timeout=self._connection.timeout_config.insert)
Expand Down

0 comments on commit e98b0aa

Please sign in to comment.