diff --git a/langchain4j-oracle/src/main/java/dev/langchain4j/store/embedding/oracle/IVFIndexBuilder.java b/langchain4j-oracle/src/main/java/dev/langchain4j/store/embedding/oracle/IVFIndexBuilder.java index 20f2a086d0a..8db1ade2491 100644 --- a/langchain4j-oracle/src/main/java/dev/langchain4j/store/embedding/oracle/IVFIndexBuilder.java +++ b/langchain4j-oracle/src/main/java/dev/langchain4j/store/embedding/oracle/IVFIndexBuilder.java @@ -25,7 +25,7 @@ public class IVFIndexBuilder extends IndexBuilder { private int minVectorsPerPartition = -1; - IVFIndexBuilder() { } + IVFIndexBuilder() {} /** * Configures the target accuracy. @@ -35,9 +35,9 @@ public class IVFIndexBuilder extends IndexBuilder { * @throws IllegalArgumentException If the target accuracy not between 1 and 100. */ public IVFIndexBuilder targetAccuracy(int targetAccuracy) throws IllegalArgumentException { - ensureBetween(targetAccuracy, 0, 100, "targetAccuracy"); - this.targetAccuracy = targetAccuracy; - return this; + ensureBetween(targetAccuracy, 0, 100, "targetAccuracy"); + this.targetAccuracy = targetAccuracy; + return this; } /** @@ -47,9 +47,9 @@ public IVFIndexBuilder targetAccuracy(int targetAccuracy) throws IllegalArgument * @return This builder. */ public IVFIndexBuilder degreeOfParallelism(int degreeOfParallelism) { - ensureGreaterThanZero(degreeOfParallelism, "degreeOfParallelism"); - this.degreeOfParallelism = degreeOfParallelism; - return this; + ensureGreaterThanZero(degreeOfParallelism, "degreeOfParallelism"); + this.degreeOfParallelism = degreeOfParallelism; + return this; } /** @@ -65,9 +65,9 @@ public IVFIndexBuilder degreeOfParallelism(int degreeOfParallelism) { * 10000000, or if the vector type is not IVF. 
*/ public IVFIndexBuilder neighborPartitions(int neighborPartitions) throws IllegalArgumentException { - ensureBetween(neighborPartitions, 1, 10000000, "neighborPartitions"); - this.neighborPartitions = neighborPartitions; - return this; + ensureBetween(neighborPartitions, 1, 10000000, "neighborPartitions"); + this.neighborPartitions = neighborPartitions; + return this; } /** @@ -88,9 +88,9 @@ public IVFIndexBuilder neighborPartitions(int neighborPartitions) throws Illegal * @throws IllegalArgumentException If the number of samples per partition is lower than 1. */ public IVFIndexBuilder samplePerPartition(int samplePerPartition) throws IllegalArgumentException { - ensureBetween(samplePerPartition, 1, Integer.MAX_VALUE, "samplePerPartition"); - this.samplePerPartition = samplePerPartition; - return this; + ensureBetween(samplePerPartition, 1, Integer.MAX_VALUE, "samplePerPartition"); + this.samplePerPartition = samplePerPartition; + return this; } /** @@ -108,66 +108,65 @@ public IVFIndexBuilder samplePerPartition(int samplePerPartition) throws Illegal * than 0. */ public IVFIndexBuilder minVectorsPerPartition(int minVectorsPerPartition) throws IllegalArgumentException { - ensureGreaterThanZero(minVectorsPerPartition, "minVectorsPerPartition"); - this.minVectorsPerPartition = minVectorsPerPartition; - return this; + ensureGreaterThanZero(minVectorsPerPartition, "minVectorsPerPartition"); + this.minVectorsPerPartition = minVectorsPerPartition; + return this; } - - /** - * {@inheritDoc} - */ - @Override + /** + * {@inheritDoc} + */ + @Override public Index build() { - return new Index(this); + return new Index(this); } - /** - * @inheritDoc - */ - @Override - String getCreateIndexStatement(EmbeddingTable embeddingTable) { - - return "CREATE VECTOR INDEX " + - (createOption == CreateOption.CREATE_IF_NOT_EXISTS ? 
"IF NOT EXISTS " : "") + - getIndexName(embeddingTable) + - " ON " + embeddingTable.name() + "( " + embeddingTable.embeddingColumn() + " ) " + - " ORGANIZATION NEIGHBOR PARTITIONS " + - " WITH DISTANCE COSINE " + - (targetAccuracy > 0 ? " WITH TARGET ACCURACY " + targetAccuracy + " " : "") + - (degreeOfParallelism >= 0 ? " PARALLEL " + degreeOfParallelism : "") + - getIndexParameters(); - } - - /** - * {@inheritDoc} - *

- * The index name id generated by concatenating "_VECTOR_INDEX" to the embedding table - * name. - *

- * @param embeddingTable The embedding table. - * @return The name of the index. - */ - @Override - String getIndexName(EmbeddingTable embeddingTable) { - if (indexName == null) { - indexName = buildIndexName(embeddingTable.name(), "_VECTOR_INDEX"); + /** + * @inheritDoc + */ + @Override + String getCreateIndexStatement(EmbeddingTable embeddingTable) { + + return "CREATE VECTOR INDEX " + (createOption == CreateOption.CREATE_IF_NOT_EXISTS ? "IF NOT EXISTS " : "") + + getIndexName(embeddingTable) + + " ON " + + embeddingTable.name() + "( " + embeddingTable.embeddingColumn() + " ) " + + " ORGANIZATION NEIGHBOR PARTITIONS " + + " WITH DISTANCE COSINE " + + (targetAccuracy > 0 ? " WITH TARGET ACCURACY " + targetAccuracy + " " : "") + + (degreeOfParallelism >= 0 ? " PARALLEL " + degreeOfParallelism : "") + + getIndexParameters(); } - return indexName; - } - - /** - * Generates the PARAMETERS clause of the vector index. Implementation depends on the type of vector index. - * @return A string containing the PARAMETERS clause of the index. - */ - String getIndexParameters() { - if (neighborPartitions == -1 && samplePerPartition == -1 && minVectorsPerPartition == -1) { - return " "; + + /** + * {@inheritDoc} + *

+ * The index name is generated by concatenating "_VECTOR_INDEX" to the embedding table + * name. + *

+ * @param embeddingTable The embedding table. + * @return The name of the index. + */ + @Override + String getIndexName(EmbeddingTable embeddingTable) { + if (indexName == null) { + indexName = buildIndexName(embeddingTable.name(), "_VECTOR_INDEX"); + } + return indexName; } - return "PARAMETERS ( TYPE IVF" + - (neighborPartitions != -1 ? ", NEIGHBOR PARTITIONS " + neighborPartitions + " " : "") + - (samplePerPartition != -1 ? ", SAMPLES_PER_PARTITION " + samplePerPartition + " " : "") + - (minVectorsPerPartition != -1 ? ", MIN_VECTORS_PER_PARTITION " + minVectorsPerPartition + " " : "") + ")"; - } -} + /** + * Generates the PARAMETERS clause of the vector index. Implementation depends on the type of vector index. + * @return A string containing the PARAMETERS clause of the index. + */ + String getIndexParameters() { + if (neighborPartitions == -1 && samplePerPartition == -1 && minVectorsPerPartition == -1) { + return " "; + } + return "PARAMETERS ( TYPE IVF" + + (neighborPartitions != -1 ? ", NEIGHBOR PARTITIONS " + neighborPartitions + " " : "") + + (samplePerPartition != -1 ? ", SAMPLES_PER_PARTITION " + samplePerPartition + " " : "") + + (minVectorsPerPartition != -1 ? 
", MIN_VECTORS_PER_PARTITION " + minVectorsPerPartition + " " : "") + + ")"; + } +} diff --git a/langchain4j-oracle/src/main/java/dev/langchain4j/store/embedding/oracle/JSONIndexBuilder.java b/langchain4j-oracle/src/main/java/dev/langchain4j/store/embedding/oracle/JSONIndexBuilder.java index 01476a8f067..9aa764b7c2d 100644 --- a/langchain4j-oracle/src/main/java/dev/langchain4j/store/embedding/oracle/JSONIndexBuilder.java +++ b/langchain4j-oracle/src/main/java/dev/langchain4j/store/embedding/oracle/JSONIndexBuilder.java @@ -1,14 +1,12 @@ package dev.langchain4j.store.embedding.oracle; -import oracle.jdbc.OracleType; +import static dev.langchain4j.internal.ValidationUtils.ensureNotBlank; +import static dev.langchain4j.internal.ValidationUtils.ensureNotNull; import java.util.ArrayList; import java.util.List; -import java.util.UUID; import java.util.stream.Collectors; - -import static dev.langchain4j.internal.ValidationUtils.ensureNotBlank; -import static dev.langchain4j.internal.ValidationUtils.ensureNotNull; +import oracle.jdbc.OracleType; /** *

@@ -22,165 +20,166 @@ */ public class JSONIndexBuilder extends IndexBuilder { - /** - * Indicates whether the index is unique. - */ - private boolean isUnique; - - /** - * Indicates whether the index is a bitmap index. - */ - private boolean isBitmap; - - /** - * Create option for the index, by default create if not exists; - */ - private CreateOption createOption = CreateOption.CREATE_IF_NOT_EXISTS; - - /** - * List of index expressions of the index. An expression is added for - * each JSON key that is indexed. - */ - private final List indexExpressions = new ArrayList(); - - /** - * Use ASC or DESC to indicate whether the index should be created in ascending or - * descending order. Indexes on character data are created in ascending or descending - * order of the character values in the database character set. - */ - public enum Order { /** - * Create the index on ascending order. + * Indicates whether the index is unique. + */ + private boolean isUnique; + + /** + * Indicates whether the index is a bitmap index. + */ + private boolean isBitmap; + + /** + * Create option for the index, by default create if not exists; + */ + private CreateOption createOption = CreateOption.CREATE_IF_NOT_EXISTS; + + /** + * List of index expressions of the index. An expression is added for + * each JSON key that is indexed. + */ + private final List indexExpressions = new ArrayList(); + + /** + * Use ASC or DESC to indicate whether the index should be created in ascending or + * descending order. Indexes on character data are created in ascending or descending + * order of the character values in the database character set. + */ + public enum Order { + /** + * Create the index on ascending order. + */ + ASC, + /** + * Create the index on descending order. + */ + DESC + } + + JSONIndexBuilder() {} + + /** + * Specify UNIQUE to indicate that the value of the column (or columns) upon + * which the index is based must be unique. + * Note that you cannot specify both UNIQUE and BITMAP. 
+ * + * @param isUnique True if the index should be UNIQUE otherwise false; + * @return This builder. */ - ASC, + public JSONIndexBuilder isUnique(boolean isUnique) { + this.isUnique = isUnique; + return this; + } + /** - * Create the index on descending order. + * Specify BITMAP to indicate that index is to be created with a bitmap for each + * distinct key, rather than indexing each row separately. + * + * @param isBitmap True if the index should be BITMAP otherwise false; + * @return This builder. */ - DESC - } - - JSONIndexBuilder() { } - - /** - * Specify UNIQUE to indicate that the value of the column (or columns) upon - * which the index is based must be unique. - * Note that you cannot specify both UNIQUE and BITMAP. - * - * @param isUnique True if the index should be UNIQUE otherwise false; - * @return This builder. - */ - public JSONIndexBuilder isUnique(boolean isUnique) { - this.isUnique = isUnique; - return this; - } - - /** - * Specify BITMAP to indicate that index is to be created with a bitmap for each - * distinct key, rather than indexing each row separately. - * - * @param isBitmap True if the index should be BITMAP otherwise false; - * @return This builder. - */ - public JSONIndexBuilder isBitmap(boolean isBitmap) { - this.isBitmap = isBitmap; - return this; - } - - /** - * Adds a column expression to the index expression that allows to index the - * value of a given key of the JSON column. - * - * @param key The key to index. - * @param keyType The java class of the metadata column. - * @param order The order the index should be created in. - * @return This builder. 
- * @throws IllegalArgumentException If the key is null or empty, if the sqlType is null or if the order is null - */ - public JSONIndexBuilder key(String key, Class keyType, Order order) { - ensureNotBlank(key, "key"); - ensureNotNull(keyType, "sqlType"); - ensureNotNull(order, "order"); - indexExpressions.add(new MetadataKey(key, keyType, order)); - return this; - } - - /** - * {@inheritDoc} - */ - @Override - public Index build() { - return new Index(this); - } - - /** - * {@inheritDoc} - */ - @Override - String getCreateIndexStatement(EmbeddingTable embeddingTable) { - return "CREATE " + - (isUnique ? " UNIQUE " : "") + - (isBitmap ? " BITMAP " : "") + - " INDEX " + - (createOption == CreateOption.CREATE_IF_NOT_EXISTS ? " IF NOT EXISTS " : "") + - getIndexName(embeddingTable) + - " ON " + embeddingTable.name() + - "(" + getIndexExpression(embeddingTable) + ")"; - } - - /** - * {@inheritDoc} - *

- * The index name id generated by concatenating "_METADATA_" and the indexed key - * names separated by an underscore character to the embedding table name. - *

- */ - @Override - String getIndexName(EmbeddingTable embeddingTable) { - if (indexName == null) { - indexName = buildIndexName( - embeddingTable.name(), - "_METADATA_" + this.indexExpressions - .stream() - .map(metadataKey -> metadataKey.getKey().toUpperCase()) - .collect(Collectors.joining("_"))); + public JSONIndexBuilder isBitmap(boolean isBitmap) { + this.isBitmap = isBitmap; + return this; } - return indexName; - } - - private String getIndexExpression(EmbeddingTable embeddingTable) { - return indexExpressions.stream().map(metadataKey -> { - OracleType oracleType = SQLFilters.toOracleType(metadataKey.keyType); - return embeddingTable.mapMetadataKey(metadataKey.key, oracleType) + " " + metadataKey.order; - }).collect(Collectors.joining(",")); - } - - /** - * Private class that represents the index expression of the index. It contains - * three members: the JSON key, the java data type if the key and the order in - * which the key should be indexed (ASC or DESC). - */ - private class MetadataKey { - private String key; - private Class keyType; - private Order order; - - public MetadataKey(String key, Class keyType, Order order) { - this.key = key; - this.keyType = keyType; - this.order = order; + + /** + * Adds a column expression to the index expression that allows to index the + * value of a given key of the JSON column. + * + * @param key The key to index. + * @param keyType The java class of the metadata column. + * @param order The order the index should be created in. + * @return This builder. 
+ * @throws IllegalArgumentException If the key is null or empty, if the sqlType is null or if the order is null + */ + public JSONIndexBuilder key(String key, Class keyType, Order order) { + ensureNotBlank(key, "key"); + ensureNotNull(keyType, "sqlType"); + ensureNotNull(order, "order"); + indexExpressions.add(new MetadataKey(key, keyType, order)); + return this; } - public String getKey() { - return key; + /** + * {@inheritDoc} + */ + @Override + public Index build() { + return new Index(this); } - public Order getOrder() { - return order; + /** + * {@inheritDoc} + */ + @Override + String getCreateIndexStatement(EmbeddingTable embeddingTable) { + return "CREATE " + (isUnique ? " UNIQUE " : "") + + (isBitmap ? " BITMAP " : "") + + " INDEX " + + (createOption == CreateOption.CREATE_IF_NOT_EXISTS ? " IF NOT EXISTS " : "") + + getIndexName(embeddingTable) + + " ON " + + embeddingTable.name() + "(" + + getIndexExpression(embeddingTable) + ")"; } - public Class getKeyType() { - return keyType; + /** + * {@inheritDoc} + *

+ * The index name is generated by concatenating "_METADATA_" and the indexed key + * names separated by an underscore character to the embedding table name. + *

+ */ + @Override + String getIndexName(EmbeddingTable embeddingTable) { + if (indexName == null) { + indexName = buildIndexName( + embeddingTable.name(), + "_METADATA_" + + this.indexExpressions.stream() + .map(metadataKey -> metadataKey.getKey().toUpperCase()) + .collect(Collectors.joining("_"))); + } + return indexName; + } + + private String getIndexExpression(EmbeddingTable embeddingTable) { + return indexExpressions.stream() + .map(metadataKey -> { + OracleType oracleType = SQLFilters.toOracleType(metadataKey.keyType); + return embeddingTable.mapMetadataKey(metadataKey.key, oracleType) + " " + metadataKey.order; + }) + .collect(Collectors.joining(",")); } - } -} + /** + * Private class that represents the index expression of the index. It contains + * three members: the JSON key, the java data type if the key and the order in + * which the key should be indexed (ASC or DESC). + */ + private class MetadataKey { + private String key; + private Class keyType; + private Order order; + + public MetadataKey(String key, Class keyType, Order order) { + this.key = key; + this.keyType = keyType; + this.order = order; + } + + public String getKey() { + return key; + } + + public Order getOrder() { + return order; + } + + public Class getKeyType() { + return keyType; + } + } +} diff --git a/langchain4j-oracle/src/main/java/dev/langchain4j/store/embedding/oracle/SQLFilter.java b/langchain4j-oracle/src/main/java/dev/langchain4j/store/embedding/oracle/SQLFilter.java index 106778e5762..a8d85cc1da8 100644 --- a/langchain4j-oracle/src/main/java/dev/langchain4j/store/embedding/oracle/SQLFilter.java +++ b/langchain4j-oracle/src/main/java/dev/langchain4j/store/embedding/oracle/SQLFilter.java @@ -1,11 +1,9 @@ package dev.langchain4j.store.embedding.oracle; import dev.langchain4j.store.embedding.filter.Filter; - import java.sql.Connection; import java.sql.PreparedStatement; import java.sql.SQLException; -import java.util.function.UnaryOperator; /** *

@@ -84,5 +82,4 @@ default String asWhereClause() { * @throws SQLException If one is thrown from the PreparedStatement. */ int setParameters(PreparedStatement preparedStatement, int parameterIndex) throws SQLException; - } diff --git a/langchain4j-oracle/src/main/java/dev/langchain4j/store/embedding/oracle/SQLFilters.java b/langchain4j-oracle/src/main/java/dev/langchain4j/store/embedding/oracle/SQLFilters.java index ddd141c4289..b854376ebdc 100644 --- a/langchain4j-oracle/src/main/java/dev/langchain4j/store/embedding/oracle/SQLFilters.java +++ b/langchain4j-oracle/src/main/java/dev/langchain4j/store/embedding/oracle/SQLFilters.java @@ -6,8 +6,6 @@ import dev.langchain4j.store.embedding.filter.logical.And; import dev.langchain4j.store.embedding.filter.logical.Not; import dev.langchain4j.store.embedding.filter.logical.Or; -import oracle.jdbc.OracleType; - import java.io.StringReader; import java.sql.PreparedStatement; import java.sql.SQLException; @@ -15,6 +13,7 @@ import java.util.function.BiFunction; import java.util.stream.Collectors; import java.util.stream.Stream; +import oracle.jdbc.OracleType; /** * A factory for {@link SQLFilter} implementations. The {@link #create(Filter, BiFunction)} creates a SQLFilter that @@ -40,46 +39,39 @@ private SQLFilters() {} * Map of {@link Filter} classes to functions which construct the equivalent {@link SQLFilter}. 
*/ private static final Map, FilterConstructor> CONSTRUCTORS; + static { Map, FilterConstructor> map = new HashMap<>(); - map.put(IsEqualTo.class, (filter, keyMapper) -> - new SQLComparisonFilter((IsEqualTo) filter, keyMapper)); + map.put(IsEqualTo.class, (filter, keyMapper) -> new SQLComparisonFilter((IsEqualTo) filter, keyMapper)); - map.put(IsNotEqualTo.class, (filter, keyMapper) -> - new SQLComparisonFilter((IsNotEqualTo) filter, keyMapper)); + map.put(IsNotEqualTo.class, (filter, keyMapper) -> new SQLComparisonFilter((IsNotEqualTo) filter, keyMapper)); - map.put(IsGreaterThan.class, (filter, keyMapper) -> - new SQLComparisonFilter((IsGreaterThan) filter, keyMapper)); + map.put(IsGreaterThan.class, (filter, keyMapper) -> new SQLComparisonFilter((IsGreaterThan) filter, keyMapper)); - map.put(IsGreaterThanOrEqualTo.class, (filter, keyMapper) -> - new SQLComparisonFilter((IsGreaterThanOrEqualTo) filter, keyMapper)); + map.put( + IsGreaterThanOrEqualTo.class, + (filter, keyMapper) -> new SQLComparisonFilter((IsGreaterThanOrEqualTo) filter, keyMapper)); - map.put(IsLessThan.class, (filter, keyMapper) -> - new SQLComparisonFilter((IsLessThan) filter, keyMapper)); + map.put(IsLessThan.class, (filter, keyMapper) -> new SQLComparisonFilter((IsLessThan) filter, keyMapper)); - map.put(IsLessThanOrEqualTo.class, (filter, keyMapper) -> - new SQLComparisonFilter((IsLessThanOrEqualTo) filter, keyMapper)); + map.put( + IsLessThanOrEqualTo.class, + (filter, keyMapper) -> new SQLComparisonFilter((IsLessThanOrEqualTo) filter, keyMapper)); - map.put(IsIn.class, (filter, keyMapper) -> - SQLInFilter.create((IsIn) filter, keyMapper)); + map.put(IsIn.class, (filter, keyMapper) -> SQLInFilter.create((IsIn) filter, keyMapper)); - map.put(IsNotIn.class, (filter, keyMapper) -> - SQLInFilter.create((IsNotIn) filter, keyMapper)); + map.put(IsNotIn.class, (filter, keyMapper) -> SQLInFilter.create((IsNotIn) filter, keyMapper)); - map.put(And.class, (filter, keyMapper) -> - new 
SQLLogicalFilter((And) filter, keyMapper)); + map.put(And.class, (filter, keyMapper) -> new SQLLogicalFilter((And) filter, keyMapper)); - map.put(Or.class, (filter, keyMapper) -> - new SQLLogicalFilter((Or) filter, keyMapper)); + map.put(Or.class, (filter, keyMapper) -> new SQLLogicalFilter((Or) filter, keyMapper)); - map.put(Not.class, (filter, keyMapper) -> - new SQLNot((Not)filter, keyMapper)); + map.put(Not.class, (filter, keyMapper) -> new SQLNot((Not) filter, keyMapper)); CONSTRUCTORS = Collections.unmodifiableMap(map); } - /** *

* Returns a SQL filter that evaluates to the same result as a Filter. @@ -101,14 +93,12 @@ private SQLFilters() {} * @throws IllegalArgumentException If the class of the Filter is not recognized. */ static SQLFilter create(Filter filter, BiFunction keyMapper) { - if (filter == null) - return EMPTY; + if (filter == null) return EMPTY; Class filterClass = filter.getClass(); FilterConstructor constructor = CONSTRUCTORS.get(filterClass); - if (constructor == null) - throw new IllegalArgumentException("Unrecognized Filter class: " + filterClass); + if (constructor == null) throw new IllegalArgumentException("Unrecognized Filter class: " + filterClass); return constructor.construct(filter, keyMapper); } @@ -176,7 +166,8 @@ private static class SQLComparisonFilter implements SQLFilter { this(isGreaterThan.key(), keyMapper, ">", isGreaterThan.comparisonValue(), false); } - SQLComparisonFilter(IsGreaterThanOrEqualTo isGreaterThanOrEqualTo, BiFunction keyMapper) { + SQLComparisonFilter( + IsGreaterThanOrEqualTo isGreaterThanOrEqualTo, BiFunction keyMapper) { this(isGreaterThanOrEqualTo.key(), keyMapper, ">=", isGreaterThanOrEqualTo.comparisonValue(), false); } @@ -221,7 +212,10 @@ private static class SQLComparisonFilter implements SQLFilter { * @param isNullTrue Result of the filter when the metadata does not contain the key. */ private SQLComparisonFilter( - String key, BiFunction keyMapper, String operator, T comparisonValue, + String key, + BiFunction keyMapper, + String operator, + T comparisonValue, boolean isNullTrue) { this.sqlType = toOracleType(comparisonValue); @@ -230,11 +224,10 @@ private SQLComparisonFilter( // DBMS_LOB.COMPARE must be used for a comparison of CLOB values. Comparison operators like "=" or "<" // cannot have a CLOB operand. The COMPARE function is similar to String.compareTo(String), returning // -1, 0, or 1 for less-than, equal-to, and greater-than, respectively. - sql = "NVL(" + - "DBMS_LOB.COMPARE(" + keyMapper.apply(key, sqlType) + ", ?) 
" + operator + " 0, " + sql = "NVL(" + "DBMS_LOB.COMPARE(" + + keyMapper.apply(key, sqlType) + ", ?) " + operator + " 0, " + isNullTrue + ")"; - } - else { + } else { sql = "NVL(" + keyMapper.apply(key, sqlType) + " " + operator + " ?, " + isNullTrue + ")"; } @@ -291,7 +284,8 @@ private static class SQLLogicalFilter implements SQLFilter { this(or.left(), "OR", or.right(), keyMapper); } - private SQLLogicalFilter(Filter left, String operator, Filter right, BiFunction keyMapper) { + private SQLLogicalFilter( + Filter left, String operator, Filter right, BiFunction keyMapper) { this(create(left, keyMapper), operator, create(right, keyMapper)); } @@ -407,13 +401,13 @@ static SQLFilter create(IsNotIn isNotIn, BiFunction } static SQLFilter create( - String key, BiFunction keyMapper, boolean isIn, + String key, + BiFunction keyMapper, + boolean isIn, Collection comparisonValues) { Set sqlTypes = - comparisonValues.stream() - .map(SQLFilters::toOracleType) - .collect(Collectors.toSet()); + comparisonValues.stream().map(SQLFilters::toOracleType).collect(Collectors.toSet()); Iterator sqlTypeIterator = sqlTypes.iterator(); OracleType sqlType = sqlTypes.iterator().next(); @@ -424,11 +418,10 @@ static SQLFilter create( } // Replicate IN and NOT IN conditions as a sequence of OR conditions: "key = value0 OR key = value1 OR ..." - SQLFilter orFilter = - comparisonValues.stream() - .map(object -> new SQLComparisonFilter(key, keyMapper, "=", object, false)) - .reduce((left, right) -> new SQLLogicalFilter(left, "OR", right)) - .orElse(EMPTY); + SQLFilter orFilter = comparisonValues.stream() + .map(object -> new SQLComparisonFilter(key, keyMapper, "=", object, false)) + .reduce((left, right) -> new SQLLogicalFilter(left, "OR", right)) + .orElse(EMPTY); return isIn ? orFilter : new SQLNot(orFilter); } @@ -455,13 +448,14 @@ static SQLFilter create( * @param comparisonValues Set of values to search within. Not null. Not empty. 
*/ private SQLInFilter( - String key, BiFunction keyMapper, boolean isIn, - Collection comparisonValues, OracleType sqlType) { + String key, + BiFunction keyMapper, + boolean isIn, + Collection comparisonValues, + OracleType sqlType) { this.sqlType = sqlType; this.sql = "NVL(" + keyMapper.apply(key, sqlType) + (isIn ? " IN " : " NOT IN ") + "(" - + Stream.generate(() -> "?") - .limit(comparisonValues.size()) - .collect(Collectors.joining(", ")) + + Stream.generate(() -> "?").limit(comparisonValues.size()).collect(Collectors.joining(", ")) + "), " + !isIn + ")"; // <-- 2nd argument to NVL this.comparisonValues = comparisonValues; @@ -527,7 +521,7 @@ private static void setObject( Object jdbcObject = toJdbcObject(object); if (jdbcObject instanceof String && sqlType == OracleType.CLOB) { - String string = (String)jdbcObject; + String string = (String) jdbcObject; int length = string.length(); // Convert the String into a VARCHAR if the length is small enough. Oracle Database supports an implicit @@ -536,13 +530,11 @@ private static void setObject( // implementation of setCharacterStream). if (length < MAX_VARCHAR_LENGTH) { preparedStatement.setString(parameterIndex, (String) jdbcObject); - } - else { + } else { // Oracle JDBC converts a Reader into CLOB if setCharacterStream is called without a length argument. 
preparedStatement.setCharacterStream(parameterIndex, new StringReader(string)); } - } - else { + } else { preparedStatement.setObject(parameterIndex, jdbcObject, sqlType); } } @@ -560,8 +552,7 @@ private static void setObject( private static Object toJdbcObject(Object object) { if (object instanceof UUID) { return object.toString(); - } - else { + } else { return object; } } @@ -586,27 +577,22 @@ static OracleType toOracleType(Object object) { if (object instanceof Number) { if (object instanceof Float) { return OracleType.BINARY_FLOAT; - } - else if (object instanceof Double) { + } else if (object instanceof Double) { return OracleType.BINARY_DOUBLE; - } - else if (object instanceof Integer || object instanceof Long) { + } else if (object instanceof Integer || object instanceof Long) { // NUMBER is an integer with up to 38 decimal digits. It can represent any value of an Integer or Long. return OracleType.NUMBER; - } - else { + } else { // May need to add more branches above, if Metadata supports new object classes. throw new IllegalArgumentException("Unexpected object class: " + object.getClass()); } - } - else if (object instanceof String) { + } else if (object instanceof String) { // This String will be compared to another character value, and the length of that other character value is // not known. It cannot be assumed that the other character value's length is small enough to be a VARCHAR. // For this reason, the two character values should be compared as CLOBs. return OracleType.CLOB; - } - else { + } else { // Compare null, UUID, and any other object that Metadata supports in the future as VARCHAR objects. // It is assumed that the getOsonFromMetadata object method in OracleEmbeddingStore will convert these // objects to String. 
If the String length is 4k or less, then a VARCHAR can store the information of @@ -614,5 +600,4 @@ else if (object instanceof String) { return OracleType.VARCHAR2; } } - } diff --git a/langchain4j-oracle/src/test/java/dev/langchain4j/store/embedding/oracle/CommonTestOperations.java b/langchain4j-oracle/src/test/java/dev/langchain4j/store/embedding/oracle/CommonTestOperations.java index a474b08ccf6..b5170c2543b 100644 --- a/langchain4j-oracle/src/test/java/dev/langchain4j/store/embedding/oracle/CommonTestOperations.java +++ b/langchain4j-oracle/src/test/java/dev/langchain4j/store/embedding/oracle/CommonTestOperations.java @@ -1,20 +1,14 @@ package dev.langchain4j.store.embedding.oracle; +import static org.assertj.core.api.Assertions.assertThat; + import dev.langchain4j.data.embedding.Embedding; import dev.langchain4j.data.segment.TextSegment; -import dev.langchain4j.model.embedding.onnx.allminilml6v2q.AllMiniLmL6V2QuantizedEmbeddingModel; import dev.langchain4j.model.embedding.EmbeddingModel; +import dev.langchain4j.model.embedding.onnx.allminilml6v2q.AllMiniLmL6V2QuantizedEmbeddingModel; import dev.langchain4j.store.embedding.EmbeddingMatch; import dev.langchain4j.store.embedding.EmbeddingSearchRequest; import dev.langchain4j.store.embedding.EmbeddingStore; -import oracle.jdbc.OracleConnection; -import oracle.sql.CHAR; -import oracle.sql.CharacterSet; -import oracle.ucp.jdbc.PoolDataSource; -import oracle.ucp.jdbc.PoolDataSourceFactory; -import org.testcontainers.oracle.OracleContainer; - -import javax.sql.DataSource; import java.sql.Connection; import java.sql.ResultSet; import java.sql.SQLException; @@ -24,8 +18,13 @@ import java.util.List; import java.util.Random; import java.util.logging.Logger; - -import static org.assertj.core.api.Assertions.assertThat; +import javax.sql.DataSource; +import oracle.jdbc.OracleConnection; +import oracle.sql.CHAR; +import oracle.sql.CharacterSet; +import oracle.ucp.jdbc.PoolDataSource; +import 
oracle.ucp.jdbc.PoolDataSourceFactory; +import org.testcontainers.oracle.OracleContainer; /** * A collection of operations which are shared by tests in this package. @@ -45,8 +44,9 @@ final class CommonTestOperations { * Seed for random numbers. When a test fails, "-Ddev.langchain4j.store.embedding.oracle.SEED=..." can be used to * re-execute it with the same random numbers. */ - private static final long SEED = Long.getLong( - "dev.langchain4j.store.embedding.oracle.SEED", System.currentTimeMillis()); + private static final long SEED = + Long.getLong("dev.langchain4j.store.embedding.oracle.SEED", System.currentTimeMillis()); + static { Logger.getLogger(CommonTestOperations.class.getName()) .info("dev.langchain4j.store.embedding.oracle.SEED=" + SEED); @@ -71,8 +71,7 @@ private CommonTestOperations() {} if (urlFromEnv == null) { // The Ryuk component is relied upon to stop this container. - OracleContainer oracleContainer = - new OracleContainer(ORACLE_IMAGE_NAME) + OracleContainer oracleContainer = new OracleContainer(ORACLE_IMAGE_NAME) .withStartupTimeout(Duration.ofSeconds(600)) .withConnectTimeoutSeconds(600) .withDatabaseName("pdb1") @@ -80,18 +79,25 @@ private CommonTestOperations() {} .withPassword("testpwd"); oracleContainer.start(); - initDataSource(DATA_SOURCE, - oracleContainer.getJdbcUrl(), oracleContainer.getUsername(), oracleContainer.getPassword()); - initDataSource(SYSDBA_DATA_SOURCE, - oracleContainer.getJdbcUrl(), "sys", oracleContainer.getPassword()); + initDataSource( + DATA_SOURCE, + oracleContainer.getJdbcUrl(), + oracleContainer.getUsername(), + oracleContainer.getPassword()); + initDataSource(SYSDBA_DATA_SOURCE, oracleContainer.getJdbcUrl(), "sys", oracleContainer.getPassword()); } else { - initDataSource(DATA_SOURCE, - urlFromEnv, System.getenv("ORACLE_JDBC_USER"), System.getenv("ORACLE_JDBC_PASSWORD")); - initDataSource(SYSDBA_DATA_SOURCE, - urlFromEnv, System.getenv("ORACLE_JDBC_USER"), System.getenv("ORACLE_JDBC_PASSWORD")); + 
initDataSource( + DATA_SOURCE, + urlFromEnv, + System.getenv("ORACLE_JDBC_USER"), + System.getenv("ORACLE_JDBC_PASSWORD")); + initDataSource( + SYSDBA_DATA_SOURCE, + urlFromEnv, + System.getenv("ORACLE_JDBC_USER"), + System.getenv("ORACLE_JDBC_PASSWORD")); } - SYSDBA_DATA_SOURCE.setConnectionProperty(OracleConnection.CONNECTION_PROPERTY_INTERNAL_LOGON, - "SYSDBA"); + SYSDBA_DATA_SOURCE.setConnectionProperty(OracleConnection.CONNECTION_PROPERTY_INTERNAL_LOGON, "SYSDBA"); } catch (SQLException sqlException) { throw new AssertionError(sqlException); @@ -104,11 +110,9 @@ static void initDataSource(PoolDataSource dataSource, String url, String usernam dataSource.setURL(url); dataSource.setUser(username); dataSource.setPassword(password); - } catch ( - SQLException sqlException) { + } catch (SQLException sqlException) { throw new AssertionError(sqlException); } - } static EmbeddingModel getEmbeddingModel() { @@ -119,7 +123,9 @@ static DataSource getDataSource() { return DATA_SOURCE; } - static DataSource getSysDBADataSource() { return SYSDBA_DATA_SOURCE; } + static DataSource getSysDBADataSource() { + return SYSDBA_DATA_SOURCE; + } /** * Returns an embedding store configured to use a table with the common {@link #TABLE_NAME}. 
Any existing table @@ -165,7 +171,7 @@ static void dropTable() throws SQLException { */ static void dropTable(String tableName) throws SQLException { try (Connection connection = DATA_SOURCE.getConnection(); - Statement statement = connection.createStatement()) { + Statement statement = connection.createStatement()) { statement.addBatch("DROP INDEX IF EXISTS " + tableName + "_EMBEDDING_INDEX"); statement.addBatch("DROP TABLE IF EXISTS " + tableName); statement.executeBatch(); @@ -178,8 +184,8 @@ static void dropTable(String tableName) throws SQLException { */ static CharacterSet getCharacterSet() throws SQLException { try (Connection connection = CommonTestOperations.getDataSource().getConnection(); - Statement statement = connection.createStatement(); - ResultSet resultSet = statement.executeQuery("SELECT 'c' FROM sys.dual")) { + Statement statement = connection.createStatement(); + ResultSet resultSet = statement.executeQuery("SELECT 'c' FROM sys.dual")) { resultSet.next(); return resultSet.getObject(1, CHAR.class).getCharacterSet(); } @@ -194,8 +200,7 @@ static CharacterSet getCharacterSet() throws SQLException { static float[] randomFloats(int length) { float[] floats = new float[length]; - for (int i = 0; i < floats.length; i++) - floats[i] = RANDOM.nextFloat(); + for (int i = 0; i < floats.length; i++) floats[i] = RANDOM.nextFloat(); return floats; } @@ -214,8 +219,7 @@ static void verifySearch(EmbeddingStore embeddingStore) { float[] vector1 = vector0.clone(); // Only higher indexes are increased in order to effect the cosine angle, and not just magnitude - for (int i = 0; i < vector1.length / 2; i++) - vector1[i] += 0.1f; + for (int i = 0; i < vector1.length / 2; i++) vector1[i] += 0.1f; List embeddings = new ArrayList<>(2); embeddings.add(Embedding.from(vector0)); @@ -231,9 +235,7 @@ static void verifySearch(EmbeddingStore embeddingStore) { // Verify the first vector is matched EmbeddingMatch match = - embeddingStore.search(request) - .matches() - .get(0); 
+ embeddingStore.search(request).matches().get(0); assertThat(match.embeddingId()).isEqualTo(ids.get(1)); assertThat(match.embedding().vector()).containsExactly(vector1); } diff --git a/langchain4j-oracle/src/test/java/dev/langchain4j/store/embedding/oracle/MetadataIndexStoreWithFilteringIT.java b/langchain4j-oracle/src/test/java/dev/langchain4j/store/embedding/oracle/MetadataIndexStoreWithFilteringIT.java index 0b3aa31eefc..1df14ba2a86 100644 --- a/langchain4j-oracle/src/test/java/dev/langchain4j/store/embedding/oracle/MetadataIndexStoreWithFilteringIT.java +++ b/langchain4j-oracle/src/test/java/dev/langchain4j/store/embedding/oracle/MetadataIndexStoreWithFilteringIT.java @@ -5,19 +5,20 @@ public class MetadataIndexStoreWithFilteringIT extends OracleEmbeddingStoreWithFilteringIT { - private final OracleEmbeddingStore embeddingStore = CommonTestOperations.newEmbeddingStoreBuilder() - .index( - Index.jsonIndexBuilder() - .createOption(CreateOption.CREATE_OR_REPLACE) - .key("name", String.class, JSONIndexBuilder.Order.ASC) - .build(), - Index.jsonIndexBuilder() - .createOption(CreateOption.CREATE_OR_REPLACE) - .key("age", Float.class, JSONIndexBuilder.Order.ASC) - .build()) - .build(); - @Override - protected EmbeddingStore embeddingStore() { - return embeddingStore; - } + private final OracleEmbeddingStore embeddingStore = CommonTestOperations.newEmbeddingStoreBuilder() + .index( + Index.jsonIndexBuilder() + .createOption(CreateOption.CREATE_OR_REPLACE) + .key("name", String.class, JSONIndexBuilder.Order.ASC) + .build(), + Index.jsonIndexBuilder() + .createOption(CreateOption.CREATE_OR_REPLACE) + .key("age", Float.class, JSONIndexBuilder.Order.ASC) + .build()) + .build(); + + @Override + protected EmbeddingStore embeddingStore() { + return embeddingStore; + } } diff --git a/langchain4j-oracle/src/test/java/dev/langchain4j/store/embedding/oracle/MetadataIndexStoreWithRemovalIT.java 
b/langchain4j-oracle/src/test/java/dev/langchain4j/store/embedding/oracle/MetadataIndexStoreWithRemovalIT.java index 8e15f68b2a4..2f5f6602898 100644 --- a/langchain4j-oracle/src/test/java/dev/langchain4j/store/embedding/oracle/MetadataIndexStoreWithRemovalIT.java +++ b/langchain4j-oracle/src/test/java/dev/langchain4j/store/embedding/oracle/MetadataIndexStoreWithRemovalIT.java @@ -4,19 +4,20 @@ import dev.langchain4j.store.embedding.EmbeddingStore; public class MetadataIndexStoreWithRemovalIT extends OracleEmbeddingStoreWithRemovalIT { - private final OracleEmbeddingStore embeddingStore = CommonTestOperations.newEmbeddingStoreBuilder() - .index( - Index.jsonIndexBuilder() - .createOption(CreateOption.CREATE_OR_REPLACE) - .key("name", String.class, JSONIndexBuilder.Order.ASC) - .build(), - Index.jsonIndexBuilder() - .createOption(CreateOption.CREATE_OR_REPLACE) - .key("age", Float.class, JSONIndexBuilder.Order.ASC) - .build()) - .build(); - @Override - protected EmbeddingStore embeddingStore() { - return embeddingStore; - } + private final OracleEmbeddingStore embeddingStore = CommonTestOperations.newEmbeddingStoreBuilder() + .index( + Index.jsonIndexBuilder() + .createOption(CreateOption.CREATE_OR_REPLACE) + .key("name", String.class, JSONIndexBuilder.Order.ASC) + .build(), + Index.jsonIndexBuilder() + .createOption(CreateOption.CREATE_OR_REPLACE) + .key("age", Float.class, JSONIndexBuilder.Order.ASC) + .build()) + .build(); + + @Override + protected EmbeddingStore embeddingStore() { + return embeddingStore; + } } diff --git a/langchain4j-oracle/src/test/java/dev/langchain4j/store/embedding/oracle/OracleEmbeddingStoreWithFilteringIT.java b/langchain4j-oracle/src/test/java/dev/langchain4j/store/embedding/oracle/OracleEmbeddingStoreWithFilteringIT.java index 5a8787c3868..1657a4e194b 100644 --- a/langchain4j-oracle/src/test/java/dev/langchain4j/store/embedding/oracle/OracleEmbeddingStoreWithFilteringIT.java +++ 
b/langchain4j-oracle/src/test/java/dev/langchain4j/store/embedding/oracle/OracleEmbeddingStoreWithFilteringIT.java @@ -1,5 +1,7 @@ package dev.langchain4j.store.embedding.oracle; +import static org.assertj.core.api.Assertions.assertThat; + import dev.langchain4j.data.document.Metadata; import dev.langchain4j.data.embedding.Embedding; import dev.langchain4j.data.segment.TextSegment; @@ -14,12 +16,6 @@ import dev.langchain4j.store.embedding.filter.logical.And; import dev.langchain4j.store.embedding.filter.logical.Not; import dev.langchain4j.store.embedding.filter.logical.Or; -import org.junit.jupiter.api.AfterAll; -import org.junit.jupiter.api.Test; -import org.junit.jupiter.params.ParameterizedTest; -import org.junit.jupiter.params.provider.Arguments; -import org.junit.jupiter.params.provider.MethodSource; - import java.sql.SQLException; import java.util.Arrays; import java.util.Collection; @@ -27,8 +23,11 @@ import java.util.Map; import java.util.stream.Collectors; import java.util.stream.Stream; - -import static org.assertj.core.api.Assertions.assertThat; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; public class OracleEmbeddingStoreWithFilteringIT extends EmbeddingStoreWithFilteringIT { @@ -62,9 +61,7 @@ public void testClobTextSegment() { OracleEmbeddingStore oracleEmbeddingStore = CommonTestOperations.newEmbeddingStore(); Embedding embedding0 = TestData.randomEmbedding(); - TextSegment textSegment0 = TextSegment.from( - "0 " + STRING_32K, - Metadata.from("a", "A " + STRING_32K)); + TextSegment textSegment0 = TextSegment.from("0 " + STRING_32K, Metadata.from("a", "A " + STRING_32K)); String id0 = oracleEmbeddingStore.add(embedding0, textSegment0); float[] vector1 = embedding0.vector().clone(); @@ -73,24 +70,19 @@ public void testClobTextSegment() { vector2[0] += 20.0f; 
Embedding embedding1 = new Embedding(vector1); Embedding embedding2 = new Embedding(vector2); - TextSegment textSegment1 = TextSegment.from( - "1 " + STRING_32K, - Metadata.from("b", "B " + STRING_32K)); - TextSegment textSegment2 = TextSegment.from( - "2 " + STRING_32K, - Metadata.from("c", "C " + STRING_32K)); + TextSegment textSegment1 = TextSegment.from("1 " + STRING_32K, Metadata.from("b", "B " + STRING_32K)); + TextSegment textSegment2 = TextSegment.from("2 " + STRING_32K, Metadata.from("c", "C " + STRING_32K)); List ids = oracleEmbeddingStore.addAll( - Arrays.asList(embedding1, embedding2), - Arrays.asList(textSegment1, textSegment2)); + Arrays.asList(embedding1, embedding2), Arrays.asList(textSegment1, textSegment2)); // Round 1: No filter. Just return CLOB valued text segments and metadata. - List> matches0 = - oracleEmbeddingStore.search(EmbeddingSearchRequest.builder() - .queryEmbedding(embedding0) - .minScore(0d) - .maxResults(3) - .build()) - .matches(); + List> matches0 = oracleEmbeddingStore + .search(EmbeddingSearchRequest.builder() + .queryEmbedding(embedding0) + .minScore(0d) + .maxResults(3) + .build()) + .matches(); assertThat(matches0.size()).as(matches0.toString()).isEqualTo(3); assertThat(matches0.get(0).embeddingId()).isEqualTo(id0); @@ -106,30 +98,32 @@ public void testClobTextSegment() { // Round 2: IsEqualTo on a substring of a metadata value. The substring length is 4000. This has SQLFilter // use a VARCHAR comparison. The JSON_VALUE function will have to handle a metadata value which is too // large for a VARCHAR. If it makes defective use of TRUNCATE, this test will fail. 
- List> matches1 = - oracleEmbeddingStore.search(EmbeddingSearchRequest.builder() - .queryEmbedding(embedding0) - .minScore(0d) - .maxResults(3) - .filter(MetadataFilterBuilder.metadataKey("a") - .isEqualTo(textSegment0.metadata().getString("a").substring(0, 4000))) - .build()) - .matches(); + List> matches1 = oracleEmbeddingStore + .search(EmbeddingSearchRequest.builder() + .queryEmbedding(embedding0) + .minScore(0d) + .maxResults(3) + .filter(MetadataFilterBuilder.metadataKey("a") + .isEqualTo( + textSegment0.metadata().getString("a").substring(0, 4000))) + .build()) + .matches(); assertThat(matches1.isEmpty()).as(matches1.toString()).isTrue(); // Round 3: IsGreaterThan on a substring of a metadata value. "CC".compareTo("C") returns a positive value, so // the filter should match. "BB".compareTo("C") returns a negative value, so the filter should not match the // other two text segments. If the JSON_VALUE function uses TRUNCATE, it will be equivalent to // "C".compareTo("C"), which returns 0, and the grater than comparison will evaluate to FALSE, which is wrong. 
- List> matches2 = - oracleEmbeddingStore.search(EmbeddingSearchRequest.builder() - .queryEmbedding(embedding0) - .minScore(0d) - .maxResults(3) - .filter(MetadataFilterBuilder.metadataKey("a").isGreaterThan( + List> matches2 = oracleEmbeddingStore + .search(EmbeddingSearchRequest.builder() + .queryEmbedding(embedding0) + .minScore(0d) + .maxResults(3) + .filter(MetadataFilterBuilder.metadataKey("a") + .isGreaterThan( textSegment0.metadata().getString("a").substring(0, 4000))) - .build()) - .matches(); + .build()) + .matches(); assertThat(matches2.size()).as(matches2.toString()).isEqualTo(1); assertThat(matches0.get(0).embeddingId()).isEqualTo(id0); @@ -139,9 +133,8 @@ public void testClobTextSegment() { @ParameterizedTest @MethodSource("should_filter_by_metadata") - protected void should_filter_by_metadata(Filter metadataFilter, - List matchingMetadatas, - List notMatchingMetadatas) { + protected void should_filter_by_metadata( + Filter metadataFilter, List matchingMetadatas, List notMatchingMetadatas) { super.should_filter_by_metadata(metadataFilter, matchingMetadatas, notMatchingMetadatas); } @@ -158,14 +151,13 @@ protected void should_filter_by_metadata(Filter metadataFilter, @SuppressWarnings("unchecked") protected static Stream should_filter_by_metadata() { return Stream.concat( - EmbeddingStoreWithFilteringIT.should_filter_by_metadata(), - EmbeddingStoreWithFilteringIT.should_filter_by_metadata() - .map(Arguments::get) - .map(argumentsArray -> Arguments.of( - toClobFilter((Filter) argumentsArray[0]), - toClobMetadata((List) argumentsArray[1]), - toClobMetadata((List) argumentsArray[2])))); - + EmbeddingStoreWithFilteringIT.should_filter_by_metadata(), + EmbeddingStoreWithFilteringIT.should_filter_by_metadata() + .map(Arguments::get) + .map(argumentsArray -> Arguments.of( + toClobFilter((Filter) argumentsArray[0]), + toClobMetadata((List) argumentsArray[1]), + toClobMetadata((List) argumentsArray[2])))); } /** @@ -174,7 +166,8 @@ protected static Stream 
should_filter_by_metadata() { * parameter is set to "EXTENDED". Otherwise, the maximum size is 4000. Either way, this length will force a CLOB * conversion. */ - private static final String STRING_32K = Stream.generate(() -> "A").limit(32768).collect(Collectors.joining()); + private static final String STRING_32K = + Stream.generate(() -> "A").limit(32768).collect(Collectors.joining()); /** * Converts a Filter into one which uses a long String value. @@ -185,65 +178,40 @@ protected static Stream should_filter_by_metadata() { private static Filter toClobFilter(Filter filter) { if (filter instanceof IsEqualTo) { IsEqualTo isEqualTo = ((IsEqualTo) filter); - return new IsEqualTo( - isEqualTo.key(), - toClobValue(isEqualTo.comparisonValue())); - } - else if (filter instanceof IsNotEqualTo) { - IsNotEqualTo isNotEqualTo = (IsNotEqualTo)filter; - return new IsNotEqualTo( - isNotEqualTo.key(), - toClobValue(isNotEqualTo.comparisonValue())); - } - else if (filter instanceof IsGreaterThan) { - IsGreaterThan isGreaterThan = (IsGreaterThan)filter; - return new IsGreaterThan( - isGreaterThan.key(), - toClobValue(isGreaterThan.comparisonValue())); - } - else if (filter instanceof IsGreaterThanOrEqualTo) { - IsGreaterThanOrEqualTo isGreaterThanOrEqualTo = (IsGreaterThanOrEqualTo)filter; + return new IsEqualTo(isEqualTo.key(), toClobValue(isEqualTo.comparisonValue())); + } else if (filter instanceof IsNotEqualTo) { + IsNotEqualTo isNotEqualTo = (IsNotEqualTo) filter; + return new IsNotEqualTo(isNotEqualTo.key(), toClobValue(isNotEqualTo.comparisonValue())); + } else if (filter instanceof IsGreaterThan) { + IsGreaterThan isGreaterThan = (IsGreaterThan) filter; + return new IsGreaterThan(isGreaterThan.key(), toClobValue(isGreaterThan.comparisonValue())); + } else if (filter instanceof IsGreaterThanOrEqualTo) { + IsGreaterThanOrEqualTo isGreaterThanOrEqualTo = (IsGreaterThanOrEqualTo) filter; return new IsGreaterThanOrEqualTo( - isGreaterThanOrEqualTo.key(), - 
toClobValue(isGreaterThanOrEqualTo.comparisonValue())); - } - else if (filter instanceof IsLessThan) { - IsLessThan isLessThan = (IsLessThan)filter; - return new IsLessThan( - isLessThan.key(), - toClobValue(isLessThan.comparisonValue())); - } - else if (filter instanceof IsLessThanOrEqualTo) { - IsLessThanOrEqualTo isLessThanOrEqualTo = (IsLessThanOrEqualTo)filter; + isGreaterThanOrEqualTo.key(), toClobValue(isGreaterThanOrEqualTo.comparisonValue())); + } else if (filter instanceof IsLessThan) { + IsLessThan isLessThan = (IsLessThan) filter; + return new IsLessThan(isLessThan.key(), toClobValue(isLessThan.comparisonValue())); + } else if (filter instanceof IsLessThanOrEqualTo) { + IsLessThanOrEqualTo isLessThanOrEqualTo = (IsLessThanOrEqualTo) filter; return new IsLessThanOrEqualTo( - isLessThanOrEqualTo.key(), - toClobValue(isLessThanOrEqualTo.comparisonValue())); - } - else if (filter instanceof IsIn) { - IsIn isIn = (IsIn)filter; - return new IsIn( - isIn.key(), - toClobValue(isIn.comparisonValues())); - } - else if (filter instanceof IsNotIn) { - IsNotIn isNotIn = (IsNotIn)filter; - return new IsNotIn( - isNotIn.key(), - toClobValue(isNotIn.comparisonValues())); - } - else if (filter instanceof And) { - And and = (And)filter; + isLessThanOrEqualTo.key(), toClobValue(isLessThanOrEqualTo.comparisonValue())); + } else if (filter instanceof IsIn) { + IsIn isIn = (IsIn) filter; + return new IsIn(isIn.key(), toClobValue(isIn.comparisonValues())); + } else if (filter instanceof IsNotIn) { + IsNotIn isNotIn = (IsNotIn) filter; + return new IsNotIn(isNotIn.key(), toClobValue(isNotIn.comparisonValues())); + } else if (filter instanceof And) { + And and = (And) filter; return new And(toClobFilter(and.left()), toClobFilter(and.right())); - } - else if (filter instanceof Or) { - Or or = (Or)filter; + } else if (filter instanceof Or) { + Or or = (Or) filter; return new Or(toClobFilter(or.left()), toClobFilter(or.right())); - } - else if (filter instanceof Not) { - Not not = 
(Not)filter; + } else if (filter instanceof Not) { + Not not = (Not) filter; return new Not(toClobFilter(not.expression())); - } - else { + } else { throw new RuntimeException("Need to add a case for: " + filter.getClass()); } } @@ -292,14 +260,12 @@ private static Collection toClobValue(Collection values) { */ @SuppressWarnings("unchecked") private static T toClobValue(T value) { - if (!(value instanceof String)) - return value; + if (!(value instanceof String)) return value; String stringValue = ((String) value); - if (stringValue.length() >= STRING_32K.length()) - return (T)stringValue; + if (stringValue.length() >= STRING_32K.length()) return (T) stringValue; - return (T)(stringValue + STRING_32K); + return (T) (stringValue + STRING_32K); } } diff --git a/langchain4j-oracle/src/test/java/dev/langchain4j/store/embedding/oracle/SQLFilterIT.java b/langchain4j-oracle/src/test/java/dev/langchain4j/store/embedding/oracle/SQLFilterIT.java index 84ed80e6ef4..9698f794c6f 100644 --- a/langchain4j-oracle/src/test/java/dev/langchain4j/store/embedding/oracle/SQLFilterIT.java +++ b/langchain4j-oracle/src/test/java/dev/langchain4j/store/embedding/oracle/SQLFilterIT.java @@ -1,23 +1,21 @@ package dev.langchain4j.store.embedding.oracle; +import static org.assertj.core.api.Assertions.assertThat; + import dev.langchain4j.data.embedding.Embedding; import dev.langchain4j.data.segment.TextSegment; import dev.langchain4j.store.embedding.EmbeddingMatch; import dev.langchain4j.store.embedding.EmbeddingSearchRequest; import dev.langchain4j.store.embedding.filter.MetadataFilterBuilder; import dev.langchain4j.store.embedding.filter.comparison.*; -import oracle.jdbc.OracleType; -import org.junit.jupiter.api.Test; - import java.sql.*; import java.util.Arrays; import java.util.Collection; import java.util.List; import java.util.stream.Collectors; import java.util.stream.Stream; - -import static org.assertj.core.api.Assertions.assertThat; -import static 
org.junit.jupiter.api.Assertions.assertTrue; +import oracle.jdbc.OracleType; +import org.junit.jupiter.api.Test; /** * Verifies that {@link SQLFilter} behaves as specified in its JavaDoc. The @@ -44,8 +42,7 @@ public void testNonUniformCollection() { // Comparison values of different object types. Note that it contains value matching the "x" value of // textSegment0's metadata. - Collection comparisonValues = - Stream.of( + Collection comparisonValues = Stream.of( Integer.MIN_VALUE, textSegment0.metadata().getLong("x"), Float.MIN_VALUE, @@ -58,31 +55,30 @@ public void testNonUniformCollection() { OracleEmbeddingStore oracleEmbeddingStore = CommonTestOperations.newEmbeddingStore(); List ids = oracleEmbeddingStore.addAll( - Arrays.asList(embedding0, embedding1), - Arrays.asList(textSegment0, textSegment1)); - - List> matches0 = - oracleEmbeddingStore.search(EmbeddingSearchRequest.builder() - .queryEmbedding(embedding0) - .minScore(0d) - .maxResults(2) - .filter(isIn) // <-- IS IN filter matches textSegment0 - .build()) - .matches(); + Arrays.asList(embedding0, embedding1), Arrays.asList(textSegment0, textSegment1)); + + List> matches0 = oracleEmbeddingStore + .search(EmbeddingSearchRequest.builder() + .queryEmbedding(embedding0) + .minScore(0d) + .maxResults(2) + .filter(isIn) // <-- IS IN filter matches textSegment0 + .build()) + .matches(); assertThat(matches0.size()).as(matches0.toString()).isEqualTo(1); assertThat(matches0.get(0).embeddingId()).isEqualTo(ids.get(0)); assertThat(matches0.get(0).embedding()).isEqualTo(embedding0); assertThat(matches0.get(0).embedded()).isEqualTo(textSegment0); - List> matches1 = - oracleEmbeddingStore.search(EmbeddingSearchRequest.builder() - .queryEmbedding(embedding0) - .minScore(0d) - .maxResults(2) - .filter(isNotIn) // <-- IS NOT IN filter matches textSegment1 - .build()) - .matches(); + List> matches1 = oracleEmbeddingStore + .search(EmbeddingSearchRequest.builder() + .queryEmbedding(embedding0) + .minScore(0d) + 
.maxResults(2) + .filter(isNotIn) // <-- IS NOT IN filter matches textSegment1 + .build()) + .matches(); assertThat(matches1.size()).as(matches1.toString()).isEqualTo(1); assertThat(matches1.get(0).embeddingId()).isEqualTo(ids.get(1)); @@ -98,36 +94,48 @@ public void testNonUniformCollection() { public void testSQLType() throws SQLException { assertThat(SQLFilters.create(new IsEqualTo("x", Integer.MIN_VALUE), (key, type) -> { - assertThat(key).isEqualTo("x"); - assertThat(type).isEqualTo(OracleType.NUMBER); - return key; - }).toSQL()).isEqualTo("NVL(x = ?, false)"); + assertThat(key).isEqualTo("x"); + assertThat(type).isEqualTo(OracleType.NUMBER); + return key; + }) + .toSQL()) + .isEqualTo("NVL(x = ?, false)"); assertThat(SQLFilters.create(new IsNotEqualTo("x", Long.MAX_VALUE), (key, type) -> { - assertThat(key).isEqualTo("x"); - assertThat(type).isEqualTo(OracleType.NUMBER); - return key; - }).toSQL()).isEqualTo("NVL(x <> ?, true)"); + assertThat(key).isEqualTo("x"); + assertThat(type).isEqualTo(OracleType.NUMBER); + return key; + }) + .toSQL()) + .isEqualTo("NVL(x <> ?, true)"); assertThat(SQLFilters.create(new IsGreaterThan("x", Float.MAX_VALUE), (key, type) -> { - assertThat(key).isEqualTo("x"); - assertThat(type).isEqualTo(OracleType.BINARY_FLOAT); // REAL is 32-bit floating point - return key; - }).toSQL()).isEqualTo("NVL(x > ?, false)"); + assertThat(key).isEqualTo("x"); + assertThat(type).isEqualTo(OracleType.BINARY_FLOAT); // REAL is 32-bit floating point + return key; + }) + .toSQL()) + .isEqualTo("NVL(x > ?, false)"); assertThat(SQLFilters.create(new IsLessThan("x", Double.MIN_VALUE), (key, type) -> { - assertThat(key).isEqualTo("x"); - assertThat(type).isEqualTo(OracleType.BINARY_DOUBLE); // REAL is 64-bit floating point - return key; - }).toSQL()).isEqualTo("NVL(x < ?, false)"); + assertThat(key).isEqualTo("x"); + assertThat(type).isEqualTo(OracleType.BINARY_DOUBLE); // REAL is 64-bit floating point + return key; + }) + .toSQL()) + .isEqualTo("NVL(x < 
?, false)"); assertThat(SQLFilters.create(MetadataFilterBuilder.metadataKey("x").isIn("a", "b"), (key, type) -> { - assertThat(key).isEqualTo("x"); - assertThat(type).isEqualTo(OracleType.CLOB); - return key; - }).toSQL()).isEqualTo("(NVL(DBMS_LOB.COMPARE(x, ?) = 0, false) OR NVL(DBMS_LOB.COMPARE(x, ?) = 0, false))"); + assertThat(key).isEqualTo("x"); + assertThat(type).isEqualTo(OracleType.CLOB); + return key; + }) + .toSQL()) + .isEqualTo("(NVL(DBMS_LOB.COMPARE(x, ?) = 0, false) OR NVL(DBMS_LOB.COMPARE(x, ?) = 0, false))"); // CLOB is lossless for all Java Strings (assuming the database character set is UTF-8) assertThat(SQLFilters.create(MetadataFilterBuilder.metadataKey("x").isNotIn("c", "d"), (key, type) -> { - assertThat(key).isEqualTo("x"); - assertThat(type).isEqualTo(OracleType.CLOB); - return key; - }).toSQL()).isEqualTo("NOT((NVL(DBMS_LOB.COMPARE(x, ?) = 0, false) OR NVL(DBMS_LOB.COMPARE(x, ?) = 0, false)))"); + assertThat(key).isEqualTo("x"); + assertThat(type).isEqualTo(OracleType.CLOB); + return key; + }) + .toSQL()) + .isEqualTo("NOT((NVL(DBMS_LOB.COMPARE(x, ?) = 0, false) OR NVL(DBMS_LOB.COMPARE(x, ?) = 0, false)))"); // CLOB is lossless for all Java Strings (assuming the database character set is UTF-8) } @@ -143,7 +151,7 @@ public void testSQLType() throws SQLException { */ private void verifyLosslessConversion(SQLType sqlType, Object javaObject) throws SQLException { try (Connection connection = CommonTestOperations.getDataSource().getConnection(); - PreparedStatement preparedStatement = connection.prepareStatement("SELECT ? FROM sys.dual")) { + PreparedStatement preparedStatement = connection.prepareStatement("SELECT ? FROM sys.dual")) { preparedStatement.setObject(1, javaObject, sqlType); try (ResultSet resultSet = preparedStatement.executeQuery()) { @@ -152,5 +160,4 @@ private void verifyLosslessConversion(SQLType sqlType, Object javaObject) throws } } } - -} \ No newline at end of file +}