From 0c982199cb3e3b41bb1ba965872fd01d7a9ff4b5 Mon Sep 17 00:00:00 2001
From: Anush
Date: Sun, 10 Mar 2024 23:29:21 +0530
Subject: [PATCH] test: schema for pytest (#18)

---
 .../{upload-binaries.yml => fatjar.yml}   |   2 +-
 pom.xml                                   |   2 +-
 src/test/python/schema.py                 | 106 ++++++++++++++++++
 src/test/python/test_qdrant_ingest.py     |  72 +++++++++---
 4 files changed, 166 insertions(+), 16 deletions(-)
 rename .github/workflows/{upload-binaries.yml => fatjar.yml} (98%)
 create mode 100644 src/test/python/schema.py

diff --git a/.github/workflows/upload-binaries.yml b/.github/workflows/fatjar.yml
similarity index 98%
rename from .github/workflows/upload-binaries.yml
rename to .github/workflows/fatjar.yml
index 4580635..7daa09e 100644
--- a/.github/workflows/upload-binaries.yml
+++ b/.github/workflows/fatjar.yml
@@ -3,7 +3,7 @@ name: Build and release JAR files
 on:
   release:
     types:
-      - created
+      - created
 
 jobs:
   upload-jar:
diff --git a/pom.xml b/pom.xml
index 02f76a6..4c9f1df 100644
--- a/pom.xml
+++ b/pom.xml
@@ -39,7 +39,7 @@
     <dependency>
       <groupId>io.qdrant</groupId>
      <artifactId>client</artifactId>
-      <version>1.7.2</version>
+      <version>1.8.0</version>
     </dependency>
     <dependency>
       <groupId>com.google.guava</groupId>
diff --git a/src/test/python/schema.py b/src/test/python/schema.py
new file mode 100644
index 0000000..35a02ac
--- /dev/null
+++ b/src/test/python/schema.py
@@ -0,0 +1,106 @@
+from pyspark.sql.types import (
+    StructType,
+    StructField,
+    StringType,
+    IntegerType,
+    DoubleType,
+    ArrayType,
+    FloatType,
+)
+
+hair_schema = StructType(
+    [
+        StructField("color", StringType(), nullable=True),
+        StructField("type", StringType(), nullable=True),
+    ]
+)
+
+coordinates_schema = StructType(
+    [
+        StructField("lat", DoubleType(), nullable=True),
+        StructField("lng", DoubleType(), nullable=True),
+    ]
+)
+
+address_schema = StructType(
+    [
+        StructField("address", StringType(), nullable=True),
+        StructField("city", StringType(), nullable=True),
+        StructField("coordinates", coordinates_schema, nullable=True),
+        StructField("postalCode", StringType(), nullable=True),
+        StructField("state", StringType(), nullable=True),
+    ]
+)
+
+bank_schema = StructType(
+    [
+        StructField("cardExpire", StringType(), nullable=True),
+        StructField("cardNumber", StringType(), nullable=True),
+        StructField("cardType", StringType(), nullable=True),
+        StructField("currency", StringType(), nullable=True),
+        StructField("iban", StringType(), nullable=True),
+    ]
+)
+
+company_address_schema = StructType(
+    [
+        StructField("address", StringType(), nullable=True),
+        StructField("city", StringType(), nullable=True),
+        StructField("coordinates", coordinates_schema, nullable=True),
+        StructField("postalCode", StringType(), nullable=True),
+        StructField("state", StringType(), nullable=True),
+    ]
+)
+
+company_schema = StructType(
+    [
+        StructField("address", company_address_schema, nullable=True),
+        StructField("department", StringType(), nullable=True),
+        StructField("name", StringType(), nullable=True),
+        StructField("title", StringType(), nullable=True),
+    ]
+)
+
+crypto_schema = StructType(
+    [
+        StructField("coin", StringType(), nullable=True),
+        StructField("wallet", StringType(), nullable=True),
+        StructField("network", StringType(), nullable=True),
+    ]
+)
+
+schema = StructType(
+    [
+        StructField("id", IntegerType(), nullable=True),
+        StructField("firstName", StringType(), nullable=True),
+        StructField("lastName", StringType(), nullable=True),
+        StructField("maidenName", StringType(), nullable=True),
+        StructField("age", IntegerType(), nullable=True),
+        StructField("gender", StringType(), nullable=True),
+        StructField("email", StringType(), nullable=True),
StructField("phone", StringType(), nullable=True), + StructField("username", StringType(), nullable=True), + StructField("password", StringType(), nullable=True), + StructField("birthDate", StringType(), nullable=True), + StructField("image", StringType(), nullable=True), + StructField("bloodGroup", StringType(), nullable=True), + StructField("height", DoubleType(), nullable=True), + StructField("weight", DoubleType(), nullable=True), + StructField("eyeColor", StringType(), nullable=True), + StructField("hair", hair_schema, nullable=True), + StructField("domain", StringType(), nullable=True), + StructField("ip", StringType(), nullable=True), + StructField("address", address_schema, nullable=True), + StructField("macAddress", StringType(), nullable=True), + StructField("university", StringType(), nullable=True), + StructField("bank", bank_schema, nullable=True), + StructField("company", company_schema, nullable=True), + StructField("ein", StringType(), nullable=True), + StructField("ssn", StringType(), nullable=True), + StructField("userAgent", StringType(), nullable=True), + StructField("crypto", crypto_schema, nullable=True), + StructField("dense_vector", ArrayType(FloatType()), nullable=False), + StructField("sparse_indices", ArrayType(IntegerType()), nullable=False), + StructField("sparse_values", ArrayType(FloatType()), nullable=False), + ] +) diff --git a/src/test/python/test_qdrant_ingest.py b/src/test/python/test_qdrant_ingest.py index 9175c9a..8164005 100644 --- a/src/test/python/test_qdrant_ingest.py +++ b/src/test/python/test_qdrant_ingest.py @@ -1,14 +1,18 @@ from pathlib import Path from pyspark.sql import SparkSession -from pyspark.errors import IllegalArgumentException -import pytest + +from .schema import schema from .conftest import Qdrant input_file_path = Path(__file__).with_name("users.json") def test_upsert_unnamed_vectors(qdrant: Qdrant, spark_session: SparkSession): - df = spark_session.read.option("multiline", "true").json(str(input_file_path)) + df = ( + spark_session.read.schema(schema) + .option("multiline", "true") + .json(str(input_file_path)) + ) df.write.format("io.qdrant.spark.Qdrant").option( "qdrant_url", qdrant.url, @@ -20,7 +24,11 @@ def test_upsert_unnamed_vectors(qdrant: Qdrant, spark_session: SparkSession): def test_upsert_named_vectors(qdrant: Qdrant, spark_session: SparkSession): - df = spark_session.read.option("multiline", "true").json(str(input_file_path)) + df = ( + spark_session.read.schema(schema) + .option("multiline", "true") + .json(str(input_file_path)) + ) df.write.format("io.qdrant.spark.Qdrant").option( "qdrant_url", qdrant.url, @@ -36,7 +44,11 @@ def test_upsert_named_vectors(qdrant: Qdrant, spark_session: SparkSession): def test_upsert_multiple_named_dense_vectors( qdrant: Qdrant, spark_session: SparkSession ): - df = spark_session.read.option("multiline", "true").json(str(input_file_path)) + df = ( + spark_session.read.schema(schema) + .option("multiline", "true") + .json(str(input_file_path)) + ) df.write.format("io.qdrant.spark.Qdrant").option( "qdrant_url", qdrant.url, @@ -50,7 +62,11 @@ def test_upsert_multiple_named_dense_vectors( def test_upsert_sparse_vectors(qdrant: Qdrant, spark_session: SparkSession): - df = spark_session.read.option("multiline", "true").json(str(input_file_path)) + df = ( + spark_session.read.schema(schema) + .option("multiline", "true") + .json(str(input_file_path)) + ) df.write.format("io.qdrant.spark.Qdrant").option( "qdrant_url", qdrant.url, @@ -64,7 +80,11 @@ def test_upsert_sparse_vectors(qdrant: 
 
 
 def test_upsert_multiple_sparse_vectors(qdrant: Qdrant, spark_session: SparkSession):
-    df = spark_session.read.option("multiline", "true").json(str(input_file_path))
+    df = (
+        spark_session.read.schema(schema)
+        .option("multiline", "true")
+        .json(str(input_file_path))
+    )
     df.write.format("io.qdrant.spark.Qdrant").option(
         "qdrant_url",
         qdrant.url,
@@ -78,7 +98,11 @@ def test_upsert_multiple_sparse_vectors(qdrant: Qdrant, spark_session: SparkSess
 
 
 def test_upsert_sparse_named_dense_vectors(qdrant: Qdrant, spark_session: SparkSession):
-    df = spark_session.read.option("multiline", "true").json(str(input_file_path))
+    df = (
+        spark_session.read.schema(schema)
+        .option("multiline", "true")
+        .json(str(input_file_path))
+    )
     df.write.format("io.qdrant.spark.Qdrant").option(
         "qdrant_url",
         qdrant.url,
@@ -96,7 +120,11 @@ def test_upsert_sparse_named_dense_vectors(qdrant: Qdrant, spark_session: SparkS
 def test_upsert_sparse_unnamed_dense_vectors(
     qdrant: Qdrant, spark_session: SparkSession
 ):
-    df = spark_session.read.option("multiline", "true").json(str(input_file_path))
+    df = (
+        spark_session.read.schema(schema)
+        .option("multiline", "true")
+        .json(str(input_file_path))
+    )
     df.write.format("io.qdrant.spark.Qdrant").option(
         "qdrant_url",
         qdrant.url,
@@ -114,7 +142,11 @@ def test_upsert_sparse_unnamed_dense_vectors(
 def test_upsert_multiple_sparse_dense_vectors(
     qdrant: Qdrant, spark_session: SparkSession
 ):
-    df = spark_session.read.option("multiline", "true").json(str(input_file_path))
+    df = (
+        spark_session.read.schema(schema)
+        .option("multiline", "true")
+        .json(str(input_file_path))
+    )
     df.write.format("io.qdrant.spark.Qdrant").option(
         "qdrant_url",
         qdrant.url,
@@ -131,7 +163,11 @@ def test_upsert_multiple_sparse_dense_vectors(
 
 # Test an upsert without vectors. All the dataframe fields will be treated as payload
 def test_upsert_without_vectors(qdrant: Qdrant, spark_session: SparkSession):
-    df = spark_session.read.option("multiline", "true").json(str(input_file_path))
+    df = (
+        spark_session.read.schema(schema)
+        .option("multiline", "true")
+        .json(str(input_file_path))
+    )
     df.write.format("io.qdrant.spark.Qdrant").option(
         "qdrant_url",
         qdrant.url,
@@ -143,9 +179,17 @@ def test_upsert_without_vectors(qdrant: Qdrant, spark_session: SparkSession):
 
 
 def test_custom_id_field(qdrant: Qdrant, spark_session: SparkSession):
-    df = spark_session.read.option("multiline", "true").json(str(input_file_path))
-
-    df = spark_session.read.option("multiline", "true").json(str(input_file_path))
+    df = (
+        spark_session.read.schema(schema)
+        .option("multiline", "true")
+        .json(str(input_file_path))
+    )
+
+    df = (
+        spark_session.read.schema(schema)
+        .option("multiline", "true")
+        .json(str(input_file_path))
+    )
     df.write.format("io.qdrant.spark.Qdrant").option(
         "qdrant_url",
         qdrant.url,
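
Usage note: the schema added by this patch pins the vector columns to exact Spark
types so JSON inference cannot widen them ("dense_vector" stays array<float>,
"sparse_indices" stays array<int>). Below is a minimal standalone sketch of the
same read-then-write flow the tests exercise. It is hypothetical: only
"qdrant_url" among the write options is visible in the hunks above, so
"collection_name", "embedding_field", and "schema" are assumptions based on the
connector's documented options, and the endpoint, JAR path, and collection name
are placeholders.

    # sketch.py -- hypothetical standalone example mirroring the tests above.
    # Assumes a running Qdrant gRPC endpoint and the connector fat JAR on the
    # Spark classpath; adjust the placeholder paths and names to your setup.
    from pyspark.sql import SparkSession

    from schema import schema  # src/test/python/schema.py from this patch

    spark = SparkSession.builder.appName("qdrant-ingest-sketch").getOrCreate()

    # Read with the explicit schema instead of relying on inference, exactly
    # as the tests now do: vector columns keep their declared array types.
    df = (
        spark.read.schema(schema)
        .option("multiline", "true")
        .json("src/test/python/users.json")
    )

    # Write to Qdrant. Option names besides "qdrant_url" are assumptions taken
    # from the connector's README, not from the hunks in this patch.
    df.write.format("io.qdrant.spark.Qdrant").option(
        "qdrant_url", "http://localhost:6334"
    ).option("collection_name", "users").option(
        "embedding_field", "dense_vector"
    ).option("schema", df.schema.json()).mode("append").save()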