Skip to content

Commit

Permalink
test: schema for pytest (#18)
Browse files Browse the repository at this point in the history
  • Loading branch information
Anush008 authored Mar 10, 2024
1 parent 7ea10b8 commit 0c98219
Show file tree
Hide file tree
Showing 4 changed files with 166 additions and 16 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ name: Build and release JAR files
on:
release:
types:
- created
- created

jobs:
upload-jar:
Expand Down
2 changes: 1 addition & 1 deletion pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
<dependency>
<groupId>io.qdrant</groupId>
<artifactId>client</artifactId>
<version>1.7.2</version>
<version>1.8.0</version>
</dependency>
<dependency>
<groupId>com.google.guava</groupId>
Expand Down
106 changes: 106 additions & 0 deletions src/test/python/schema.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
from pyspark.sql.types import (
StructType,
StructField,
StringType,
IntegerType,
DoubleType,
ArrayType,
FloatType,
)

# Nested schema for the "hair" payload object: two nullable strings.
hair_schema = StructType(
    [StructField(field_name, StringType(), True) for field_name in ("color", "type")]
)

# Geographic point as a nullable lat/lng pair of doubles.
coordinates_schema = StructType(
    [StructField(axis, DoubleType(), True) for axis in ("lat", "lng")]
)

# Postal address; "coordinates" embeds the nested lat/lng struct.
# Field order matters for schema equality, so keep it as listed.
address_schema = StructType(
    [
        StructField(field_name, field_type, True)
        for field_name, field_type in (
            ("address", StringType()),
            ("city", StringType()),
            ("coordinates", coordinates_schema),
            ("postalCode", StringType()),
            ("state", StringType()),
        )
    ]
)

# Payment/bank details of a user — all plain nullable strings.
bank_schema = StructType(
    [
        StructField(field_name, StringType(), True)
        for field_name in ("cardExpire", "cardNumber", "cardType", "currency", "iban")
    ]
)

# The company address has exactly the same layout as the top-level
# "address" object, so reuse that schema rather than maintaining a
# byte-identical duplicate field list that could silently drift.
# (StructType is only read here, never mutated, so sharing is safe.)
company_address_schema = address_schema

# Employer record; "address" nests the company address struct.
company_schema = StructType(
    [
        StructField(field_name, field_type, True)
        for field_name, field_type in (
            ("address", company_address_schema),
            ("department", StringType()),
            ("name", StringType()),
            ("title", StringType()),
        )
    ]
)

# Cryptocurrency info — three nullable strings.
crypto_schema = StructType(
    [
        StructField(field_name, StringType(), True)
        for field_name in ("coin", "wallet", "network")
    ]
)

# Top-level schema for users.json used by the pytest suite.
# Payload columns are nullable; the vector columns at the end are
# required (nullable=False) because the connector reads them to build
# dense/sparse Qdrant vectors. Field order is preserved exactly.
_payload_fields = (
    ("id", IntegerType()),
    ("firstName", StringType()),
    ("lastName", StringType()),
    ("maidenName", StringType()),
    ("age", IntegerType()),
    ("gender", StringType()),
    ("email", StringType()),
    ("phone", StringType()),
    ("username", StringType()),
    ("password", StringType()),
    ("birthDate", StringType()),
    ("image", StringType()),
    ("bloodGroup", StringType()),
    ("height", DoubleType()),
    ("weight", DoubleType()),
    ("eyeColor", StringType()),
    ("hair", hair_schema),
    ("domain", StringType()),
    ("ip", StringType()),
    ("address", address_schema),
    ("macAddress", StringType()),
    ("university", StringType()),
    ("bank", bank_schema),
    ("company", company_schema),
    ("ein", StringType()),
    ("ssn", StringType()),
    ("userAgent", StringType()),
    ("crypto", crypto_schema),
)

schema = StructType(
    [StructField(name, dtype, True) for name, dtype in _payload_fields]
    + [
        StructField("dense_vector", ArrayType(FloatType()), False),
        StructField("sparse_indices", ArrayType(IntegerType()), False),
        StructField("sparse_values", ArrayType(FloatType()), False),
    ]
)
72 changes: 58 additions & 14 deletions src/test/python/test_qdrant_ingest.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,18 @@
from pathlib import Path
from pyspark.sql import SparkSession
from pyspark.errors import IllegalArgumentException
import pytest

from .schema import schema
from .conftest import Qdrant

input_file_path = Path(__file__).with_name("users.json")


def test_upsert_unnamed_vectors(qdrant: Qdrant, spark_session: SparkSession):
df = spark_session.read.option("multiline", "true").json(str(input_file_path))
df = (
spark_session.read.schema(schema)
.option("multiline", "true")
.json(str(input_file_path))
)
df.write.format("io.qdrant.spark.Qdrant").option(
"qdrant_url",
qdrant.url,
Expand All @@ -20,7 +24,11 @@ def test_upsert_unnamed_vectors(qdrant: Qdrant, spark_session: SparkSession):


def test_upsert_named_vectors(qdrant: Qdrant, spark_session: SparkSession):
df = spark_session.read.option("multiline", "true").json(str(input_file_path))
df = (
spark_session.read.schema(schema)
.option("multiline", "true")
.json(str(input_file_path))
)
df.write.format("io.qdrant.spark.Qdrant").option(
"qdrant_url",
qdrant.url,
Expand All @@ -36,7 +44,11 @@ def test_upsert_named_vectors(qdrant: Qdrant, spark_session: SparkSession):
def test_upsert_multiple_named_dense_vectors(
qdrant: Qdrant, spark_session: SparkSession
):
df = spark_session.read.option("multiline", "true").json(str(input_file_path))
df = (
spark_session.read.schema(schema)
.option("multiline", "true")
.json(str(input_file_path))
)
df.write.format("io.qdrant.spark.Qdrant").option(
"qdrant_url",
qdrant.url,
Expand All @@ -50,7 +62,11 @@ def test_upsert_multiple_named_dense_vectors(


def test_upsert_sparse_vectors(qdrant: Qdrant, spark_session: SparkSession):
df = spark_session.read.option("multiline", "true").json(str(input_file_path))
df = (
spark_session.read.schema(schema)
.option("multiline", "true")
.json(str(input_file_path))
)
df.write.format("io.qdrant.spark.Qdrant").option(
"qdrant_url",
qdrant.url,
Expand All @@ -64,7 +80,11 @@ def test_upsert_sparse_vectors(qdrant: Qdrant, spark_session: SparkSession):


def test_upsert_multiple_sparse_vectors(qdrant: Qdrant, spark_session: SparkSession):
df = spark_session.read.option("multiline", "true").json(str(input_file_path))
df = (
spark_session.read.schema(schema)
.option("multiline", "true")
.json(str(input_file_path))
)
df.write.format("io.qdrant.spark.Qdrant").option(
"qdrant_url",
qdrant.url,
Expand All @@ -78,7 +98,11 @@ def test_upsert_multiple_sparse_vectors(qdrant: Qdrant, spark_session: SparkSess


def test_upsert_sparse_named_dense_vectors(qdrant: Qdrant, spark_session: SparkSession):
df = spark_session.read.option("multiline", "true").json(str(input_file_path))
df = (
spark_session.read.schema(schema)
.option("multiline", "true")
.json(str(input_file_path))
)
df.write.format("io.qdrant.spark.Qdrant").option(
"qdrant_url",
qdrant.url,
Expand All @@ -96,7 +120,11 @@ def test_upsert_sparse_named_dense_vectors(qdrant: Qdrant, spark_session: SparkS
def test_upsert_sparse_unnamed_dense_vectors(
qdrant: Qdrant, spark_session: SparkSession
):
df = spark_session.read.option("multiline", "true").json(str(input_file_path))
df = (
spark_session.read.schema(schema)
.option("multiline", "true")
.json(str(input_file_path))
)
df.write.format("io.qdrant.spark.Qdrant").option(
"qdrant_url",
qdrant.url,
Expand All @@ -114,7 +142,11 @@ def test_upsert_sparse_unnamed_dense_vectors(
def test_upsert_multiple_sparse_dense_vectors(
qdrant: Qdrant, spark_session: SparkSession
):
df = spark_session.read.option("multiline", "true").json(str(input_file_path))
df = (
spark_session.read.schema(schema)
.option("multiline", "true")
.json(str(input_file_path))
)
df.write.format("io.qdrant.spark.Qdrant").option(
"qdrant_url",
qdrant.url,
Expand All @@ -131,7 +163,11 @@ def test_upsert_multiple_sparse_dense_vectors(

# Test an upsert without vectors. All the dataframe fields will be treated as payload
def test_upsert_without_vectors(qdrant: Qdrant, spark_session: SparkSession):
df = spark_session.read.option("multiline", "true").json(str(input_file_path))
df = (
spark_session.read.schema(schema)
.option("multiline", "true")
.json(str(input_file_path))
)
df.write.format("io.qdrant.spark.Qdrant").option(
"qdrant_url",
qdrant.url,
Expand All @@ -143,9 +179,17 @@ def test_upsert_without_vectors(qdrant: Qdrant, spark_session: SparkSession):


def test_custom_id_field(qdrant: Qdrant, spark_session: SparkSession):
df = spark_session.read.option("multiline", "true").json(str(input_file_path))

df = spark_session.read.option("multiline", "true").json(str(input_file_path))
df = (
spark_session.read.schema(schema)
.option("multiline", "true")
.json(str(input_file_path))
)

df = (
spark_session.read.schema(schema)
.option("multiline", "true")
.json(str(input_file_path))
)
df.write.format("io.qdrant.spark.Qdrant").option(
"qdrant_url",
qdrant.url,
Expand Down

0 comments on commit 0c98219

Please sign in to comment.