Skip to content

Commit

Permalink
Extend integration tests.
Browse files Browse the repository at this point in the history
See #4706
Original pull request: #4882
  • Loading branch information
christophstrobl authored and mp911de committed Feb 4, 2025
1 parent e8e110e commit 2d7d7bf
Show file tree
Hide file tree
Showing 26 changed files with 1,119 additions and 387 deletions.
14 changes: 14 additions & 0 deletions spring-data-mongodb/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,20 @@
<scope>test</scope>
</dependency>

<dependency>
<groupId>org.testcontainers</groupId>
<artifactId>junit-jupiter</artifactId>
<version>${testcontainers}</version>
<scope>test</scope>
</dependency>

<dependency>
<groupId>org.testcontainers</groupId>
<artifactId>mongodb</artifactId>
<version>${testcontainers}</version>
<scope>test</scope>
</dependency>

<dependency>
<groupId>jakarta.transaction</groupId>
<artifactId>jakarta.transaction-api</artifactId>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,11 +50,11 @@ public class DefaultIndexOperations implements IndexOperations {

private static final String PARTIAL_FILTER_EXPRESSION_KEY = "partialFilterExpression";

protected final String collectionName;
protected final QueryMapper mapper;
protected final @Nullable Class<?> type;
private final String collectionName;
private final QueryMapper mapper;
private final @Nullable Class<?> type;

protected final MongoOperations mongoOperations;
private final MongoOperations mongoOperations;

/**
* Creates a new {@link DefaultIndexOperations}.
Expand Down Expand Up @@ -132,7 +132,7 @@ public String ensureIndex(IndexDefinition indexDefinition) {
}

@Nullable
protected MongoPersistentEntity<?> lookupPersistentEntity(@Nullable Class<?> entityType, String collection) {
private MongoPersistentEntity<?> lookupPersistentEntity(@Nullable Class<?> entityType, String collection) {

if (entityType != null) {
return mapper.getMappingContext().getRequiredPersistentEntity(entityType);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,10 @@ private JsonSchemaProperty computeSchemaForProperty(List<MongoPersistentProperty
Class<?> rawTargetType = computeTargetType(property); // target type before conversion
Class<?> targetType = converter.getTypeMapper().getWriteTargetTypeFor(rawTargetType); // conversion target type

if((rawTargetType.isPrimitive() || ClassUtils.isPrimitiveArray(rawTargetType)) && targetType == Object.class) {
targetType = rawTargetType;
}

if (!isCollection(property) && ObjectUtils.nullSafeEquals(rawTargetType, targetType)) {
if (property.isEntity() || mergeProperties.containsKey(stringPath)) {
List<JsonSchemaProperty> targetProperties = new ArrayList<>();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@

import org.bson.BinaryVector;
import org.bson.Document;

import org.springframework.data.domain.Limit;
import org.springframework.data.domain.Vector;
import org.springframework.data.mongodb.core.mapping.MongoVector;
Expand All @@ -38,7 +37,6 @@
/**
* Performs a semantic search on data in your Atlas cluster. This stage is only available for Atlas Vector Search.
* Vector data must be less than or equal to 4096 dimensions in width.
* <p>
* <h3>Limitations</h3> You cannot use this stage together with:
* <ul>
* <li>{@link org.springframework.data.mongodb.core.aggregation.LookupOperation Lookup} stages</li>
Expand Down Expand Up @@ -452,6 +450,18 @@ default LimitContributor vector(float... vector) {
return vector(Vector.of(vector));
}

/**
* Array of byte numbers that represent the query vector. The number type must match the indexed field value type.
* Otherwise, Atlas Vector Search doesn't return any results or errors.
*
* @param vector the query vector.
* @return
*/
@Contract("_ -> this")
default LimitContributor vector(byte... vector) {
return vector(BinaryVector.int8Vector(vector));
}

/**
* Array of double numbers that represent the query vector. The number type must match the indexed field value type.
* Otherwise, Atlas Vector Search doesn't return any results or errors.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
*/
package org.springframework.data.mongodb.core.convert;

import static org.springframework.data.convert.ConverterBuilder.*;
import static org.springframework.data.convert.ConverterBuilder.reading;

import java.math.BigDecimal;
import java.math.BigInteger;
Expand Down Expand Up @@ -47,7 +47,6 @@
import org.bson.types.Code;
import org.bson.types.Decimal128;
import org.bson.types.ObjectId;

import org.springframework.core.convert.ConversionFailedException;
import org.springframework.core.convert.TypeDescriptor;
import org.springframework.core.convert.converter.ConditionalConverter;
Expand Down Expand Up @@ -119,6 +118,8 @@ static Collection<Object> getConvertersToRegister() {
converters.add(reading(BsonUndefined.class, Object.class, it -> null));
converters.add(reading(String.class, URI.class, URI::create).andWriting(URI::toString));

converters.add(ByteArrayConverterFactory.INSTANCE);

return converters;
}

Expand Down Expand Up @@ -473,6 +474,48 @@ public Vector convert(BinaryVector source) {
}
}

@WritingConverter
enum ByteArrayConverterFactory implements ConverterFactory<byte[], Object>, ConditionalConverter {

INSTANCE;

@Override
public <T> Converter<byte[], T> getConverter(Class<T> targetType) {
return new ByteArrayConverter<>(targetType);
}

@Override
public boolean matches(TypeDescriptor sourceType, TypeDescriptor targetType) {
return targetType.getType() != Object.class && !sourceType.equals(targetType);
}

private final static class ByteArrayConverter<T> implements Converter<byte[], T> {

private final Class<T> targetType;

/**
* Creates a new {@link ByteArrayConverter} for the given target type.
*
* @param targetType must not be {@literal null}.
*/
public ByteArrayConverter(Class<T> targetType) {

Assert.notNull(targetType, "Target type must not be null");

this.targetType = targetType;
}

@Override
public T convert(byte[] source) {

if (this.targetType == BinaryVector.class) {
return (T) BinaryVector.int8Vector(source);
}
return (T) source;
}
}
}

/**
* {@link ConverterFactory} implementation converting {@link AtomicLong} into {@link Long}.
*
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2024 the original author or authors.
* Copyright 2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -18,24 +18,23 @@
import java.util.ArrayList;
import java.util.List;

import org.bson.BsonString;
import org.bson.Document;

import org.springframework.data.mapping.context.MappingContext;
import org.springframework.data.mongodb.core.MongoOperations;
import org.springframework.data.mongodb.core.aggregation.Aggregation;
import org.springframework.data.mongodb.core.aggregation.AggregationResults;
import org.springframework.data.mongodb.core.mapping.MongoPersistentEntity;
import org.springframework.data.mongodb.core.mapping.MongoPersistentProperty;
import org.springframework.data.util.TypeInformation;
import org.springframework.lang.Nullable;
import org.springframework.util.StringUtils;

import com.mongodb.client.model.SearchIndexModel;
import com.mongodb.client.model.SearchIndexType;

/**
* @author Christoph Strobl
* @author Mark Paluch
* @since 3.5
* @since 4.5
*/
public class DefaultSearchIndexOperations implements SearchIndexOperations {

Expand All @@ -48,6 +47,7 @@ public DefaultSearchIndexOperations(MongoOperations mongoOperations, Class<?> ty
}

public DefaultSearchIndexOperations(MongoOperations mongoOperations, String collectionName, @Nullable Class<?> type) {

this.collectionName = collectionName;

if (type != null) {
Expand All @@ -63,80 +63,63 @@ public DefaultSearchIndexOperations(MongoOperations mongoOperations, String coll
}

@Override
public String ensureIndex(SearchIndexDefinition indexDefinition) {

if (!(indexDefinition instanceof VectorIndex vsi)) {
throw new IllegalStateException("Index definitions must be of type VectorIndex");
}
public String createIndex(SearchIndexDefinition indexDefinition) {

Document index = indexDefinition.getIndexDocument(entityTypeInformation,
mongoOperations.getConverter().getMappingContext());

mongoOperations.getCollection(collectionName).createSearchIndexes(List
.of(new SearchIndexModel(vsi.getName(), (Document) index.get("definition"), SearchIndexType.vectorSearch())));
mongoOperations.getCollection(collectionName)
.createSearchIndexes(List.of(new SearchIndexModel(indexDefinition.getName(),
index.get("definition", Document.class), SearchIndexType.of(new BsonString(indexDefinition.getType())))));

return vsi.getName();
return indexDefinition.getName();
}

@Override
public void updateIndex(SearchIndexDefinition index) {

if (index instanceof VectorIndex) {
throw new UnsupportedOperationException("Vector Index definitions cannot be updated");
}
public void updateIndex(SearchIndexDefinition indexDefinition) {

Document indexDocument = index.getIndexDocument(entityTypeInformation,
Document indexDocument = indexDefinition.getIndexDocument(entityTypeInformation,
mongoOperations.getConverter().getMappingContext());

mongoOperations.getCollection(collectionName).updateSearchIndex(index.getName(), indexDocument);
mongoOperations.getCollection(collectionName).updateSearchIndex(indexDefinition.getName(), indexDocument);
}

@Override
public boolean exists(String indexName) {

List<Document> indexes = mongoOperations.getCollection(collectionName).listSearchIndexes().into(new ArrayList<>());

for (Document index : indexes) {
if (index.getString("name").equals(indexName)) {
return true;
}
}

return false;
return getSearchIndex(indexName) != null;
}

@Override
public List<IndexInfo> getIndexInfo() {

AggregationResults<Document> aggregate = mongoOperations.aggregate(
Aggregation.newAggregation(context -> new Document("$listSearchIndexes", new Document())), collectionName,
Document.class);
public SearchIndexStatus status(String indexName) {

ArrayList<IndexInfo> result = new ArrayList<>();
for (Document doc : aggregate) {

List<IndexField> indexFields = new ArrayList<>();
String name = doc.getString("name");
for (Object field : doc.get("latestDefinition", Document.class).get("fields", List.class)) {

if (field instanceof Document fieldInfo) {
indexFields.add(IndexField.vector(fieldInfo.getString("path")));
}
}

result.add(new IndexInfo(indexFields, name, false, false, null, false));
}
return result;
Document searchIndex = getSearchIndex(indexName);
return searchIndex != null ? SearchIndexStatus.valueOf(searchIndex.getString("status"))
: SearchIndexStatus.DOES_NOT_EXIST;
}

@Override
public void dropAllIndexes() {
getIndexInfo().forEach(indexInfo -> dropIndex(indexInfo.getName()));
getSearchIndexes(null).forEach(indexInfo -> dropIndex(indexInfo.getString("name")));
}

@Override
public void dropIndex(String name) {
mongoOperations.getCollection(collectionName).dropSearchIndex(name);
public void dropIndex(String indexName) {
mongoOperations.getCollection(collectionName).dropSearchIndex(indexName);
}

@Nullable
private Document getSearchIndex(String indexName) {

List<Document> indexes = getSearchIndexes(indexName);
return indexes.isEmpty() ? null : indexes.iterator().next();
}

private List<Document> getSearchIndexes(@Nullable String indexName) {

Document filter = StringUtils.hasText(indexName) ? new Document("name", indexName) : new Document();

return mongoOperations.getCollection(collectionName).aggregate(List.of(new Document("$listSearchIndexes", filter)))
.into(new ArrayList<>());
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,23 @@ public interface IndexOperations {
*
* @param indexDefinition must not be {@literal null}.
* @return the index name.
* @deprecated in favor of {@link #createIndex(IndexDefinition)}.
*/
@Deprecated(since = "4.5", forRemoval = true)
String ensureIndex(IndexDefinition indexDefinition);

/**
* Create the index for the provided {@link IndexDefinition} exists for the collection indicated by the entity
* class. If not it will be created.
*
* @param indexDefinition must not be {@literal null}.
* @return the index name.
* @since 4.5
*/
default String createIndex(IndexDefinition indexDefinition) {
return ensureIndex(indexDefinition);
}

/**
* Alters the index with given {@literal name}.
*
Expand Down
Loading

0 comments on commit 2d7d7bf

Please sign in to comment.