Skip to content

Commit

Permalink
refactor: Further type validate (#30)
Browse files Browse the repository at this point in the history
  • Loading branch information
Anush008 authored Jul 11, 2024
1 parent 15521a1 commit 9484a64
Showing 1 changed file with 20 additions and 1 deletion.
21 changes: 20 additions & 1 deletion src/main/java/io/qdrant/spark/QdrantVectorHandler.java
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import java.util.Map;
import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.catalyst.util.ArrayData;
import org.apache.spark.sql.types.ArrayType;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
Expand Down Expand Up @@ -98,6 +99,12 @@ private static float[] extractFloatArray(InternalRow record, int fieldIndex, Dat
throw new IllegalArgumentException("Vector field must be of type ArrayType");
}

ArrayType arrayType = (ArrayType) dataType;

if (!arrayType.elementType().typeName().equalsIgnoreCase("float")) {
throw new IllegalArgumentException("Expected array elements to be of FloatType");
}

return record.getArray(fieldIndex).toFloatArray();
}

Expand All @@ -107,14 +114,26 @@ private static int[] extractIntArray(InternalRow record, int fieldIndex, DataTyp
throw new IllegalArgumentException("Vector field must be of type ArrayType");
}

ArrayType arrayType = (ArrayType) dataType;

if (!arrayType.elementType().typeName().equalsIgnoreCase("integer")) {
throw new IllegalArgumentException("Expected array elements to be of IntegerType");
}

return record.getArray(fieldIndex).toIntArray();
}

private static float[][] extractMultiVecArray(
InternalRow record, int fieldIndex, DataType dataType) {

if (!dataType.typeName().equalsIgnoreCase("array")) {
throw new IllegalArgumentException("Vector field must be of type ArrayType");
throw new IllegalArgumentException("Multi Vector field must be of type ArrayType");
}

ArrayType arrayType = (ArrayType) dataType;

if (!arrayType.elementType().typeName().equalsIgnoreCase("array")) {
throw new IllegalArgumentException("Multi Vector elements must be of type ArrayType");
}

ArrayData arrayData = record.getArray(fieldIndex);
Expand Down

0 comments on commit 9484a64

Please sign in to comment.