Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

community(vectorstore): fix CosmosDB NoSQL create_container #29720

Closed
wants to merge 1 commit into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -58,10 +58,10 @@ def __init__(
*,
cosmos_client: CosmosClient,
embedding: Embeddings,
vector_embedding_policy: Dict[str, Any],
indexing_policy: Dict[str, Any],
cosmos_container_properties: Dict[str, Any],
cosmos_database_properties: Dict[str, Any],
vector_embedding_policy: Optional[Dict[str, Any]] = None,
indexing_policy: Optional[Dict[str, Any]] = None,
cosmos_container_properties: Optional[Dict[str, Any]] = None,
cosmos_database_properties: Optional[Dict[str, Any]] = None,
full_text_policy: Optional[Dict[str, Any]] = None,
database_name: str = "vectorSearchDB",
container_name: str = "vectorSearchContainer",
Expand Down Expand Up @@ -107,6 +107,20 @@ def __init__(
self._full_text_search_enabled = full_text_search_enabled

if self._create_container:
if self._cosmos_database_properties is None:
raise ValueError(
"cosmos_database_properties cannot be null in creation mode."
)
if self._cosmos_container_properties is None:
raise ValueError(
"cosmos_container_properties cannot be null in creation mode."
)
if self._indexing_policy is None:
raise ValueError("indexing_policy cannot be null in creation mode.")
if self._vector_embedding_policy is None:
raise ValueError(
"vector_embedding_policy cannot be null in creation mode."
)
if (
self._indexing_policy["vectorIndexes"] is None
or len(self._indexing_policy["vectorIndexes"]) == 0
Expand All @@ -115,8 +129,8 @@ def __init__(
"vectorIndexes cannot be null or empty in the indexing_policy."
)
if (
self._vector_embedding_policy is None
or len(vector_embedding_policy["vectorEmbeddings"]) == 0
self._vector_embedding_policy["vectorEmbeddings"] is None
or len(self._vector_embedding_policy["vectorEmbeddings"]) == 0
):
raise ValueError(
"vectorEmbeddings cannot be null "
Expand Down Expand Up @@ -144,42 +158,57 @@ def __init__(
"full_text_policy if full text search is enabled."
)

# Create the database if it already doesn't exist
self._database = self._cosmos_client.create_database_if_not_exists(
id=self._database_name,
offer_throughput=self._cosmos_database_properties.get("offer_throughput"),
session_token=self._cosmos_database_properties.get("session_token"),
initial_headers=self._cosmos_database_properties.get("initial_headers"),
etag=self._cosmos_database_properties.get("etag"),
match_condition=self._cosmos_database_properties.get("match_condition"),
)
# Create the database if it already doesn't exist
self._database = self._cosmos_client.create_database_if_not_exists(
id=self._database_name,
offer_throughput=self._cosmos_database_properties.get(
"offer_throughput"
),
session_token=self._cosmos_database_properties.get("session_token"),
initial_headers=self._cosmos_database_properties.get("initial_headers"),
etag=self._cosmos_database_properties.get("etag"),
match_condition=self._cosmos_database_properties.get("match_condition"),
)

# Create the collection if it already doesn't exist
self._container = self._database.create_container_if_not_exists(
id=self._container_name,
partition_key=self._cosmos_container_properties["partition_key"],
indexing_policy=self._indexing_policy,
default_ttl=self._cosmos_container_properties.get("default_ttl"),
offer_throughput=self._cosmos_container_properties.get("offer_throughput"),
unique_key_policy=self._cosmos_container_properties.get(
"unique_key_policy"
),
conflict_resolution_policy=self._cosmos_container_properties.get(
"conflict_resolution_policy"
),
analytical_storage_ttl=self._cosmos_container_properties.get(
"analytical_storage_ttl"
),
computed_properties=self._cosmos_container_properties.get(
"computed_properties"
),
etag=self._cosmos_container_properties.get("etag"),
match_condition=self._cosmos_container_properties.get("match_condition"),
session_token=self._cosmos_container_properties.get("session_token"),
initial_headers=self._cosmos_container_properties.get("initial_headers"),
vector_embedding_policy=self._vector_embedding_policy,
full_text_policy=self._full_text_policy,
)
# Create the collection if it already doesn't exist
self._container = self._database.create_container_if_not_exists(
id=self._container_name,
partition_key=self._cosmos_container_properties["partition_key"],
indexing_policy=self._indexing_policy,
default_ttl=self._cosmos_container_properties.get("default_ttl"),
offer_throughput=self._cosmos_container_properties.get(
"offer_throughput"
),
unique_key_policy=self._cosmos_container_properties.get(
"unique_key_policy"
),
conflict_resolution_policy=self._cosmos_container_properties.get(
"conflict_resolution_policy"
),
analytical_storage_ttl=self._cosmos_container_properties.get(
"analytical_storage_ttl"
),
computed_properties=self._cosmos_container_properties.get(
"computed_properties"
),
etag=self._cosmos_container_properties.get("etag"),
match_condition=self._cosmos_container_properties.get(
"match_condition"
),
session_token=self._cosmos_container_properties.get("session_token"),
initial_headers=self._cosmos_container_properties.get(
"initial_headers"
),
vector_embedding_policy=self._vector_embedding_policy,
full_text_policy=self._full_text_policy,
)
else:
self._database = self._cosmos_client.get_database_client(
database=self._database_name
)
self._container = self._database.get_container_client(
container=self._container_name
)

def add_texts(
self,
Expand Down
Loading