diff --git a/langchain4j-pgvector-spring-boot-starter/pom.xml b/langchain4j-pgvector-spring-boot-starter/pom.xml new file mode 100644 index 00000000..d2492648 --- /dev/null +++ b/langchain4j-pgvector-spring-boot-starter/pom.xml @@ -0,0 +1,82 @@ + + + 4.0.0 + + dev.langchain4j + langchain4j-spring + 0.37.0-SNAPSHOT + ../pom.xml + + + langchain4j-pgvector-spring-boot-starter + LangChain4j Spring Boot starter for PgVector + jar + + + + + dev.langchain4j + langchain4j-pgvector + + + + org.springframework.boot + spring-boot-starter + + + + org.springframework.boot + spring-boot-autoconfigure-processor + true + + + + + org.springframework.boot + spring-boot-configuration-processor + true + + + + org.springframework.boot + spring-boot-starter-test + test + + + + dev.langchain4j + langchain4j-embeddings-all-minilm-l6-v2-q + test + + + + dev.langchain4j + langchain4j-spring-boot-tests + ${project.version} + tests + test-jar + test + + + + org.testcontainers + postgresql + test + + + + org.tinylog + tinylog-impl + test + + + + org.tinylog + slf4j-tinylog + test + + + + \ No newline at end of file diff --git a/langchain4j-pgvector-spring-boot-starter/src/main/java/dev/langchain4j/store/embedding/pgvector/spring/PgVectorDataSourceProperties.java b/langchain4j-pgvector-spring-boot-starter/src/main/java/dev/langchain4j/store/embedding/pgvector/spring/PgVectorDataSourceProperties.java new file mode 100644 index 00000000..08ee0126 --- /dev/null +++ b/langchain4j-pgvector-spring-boot-starter/src/main/java/dev/langchain4j/store/embedding/pgvector/spring/PgVectorDataSourceProperties.java @@ -0,0 +1,22 @@ +package dev.langchain4j.store.embedding.pgvector.spring; + +import org.springframework.boot.context.properties.ConfigurationProperties; + +@ConfigurationProperties(prefix = PgVectorDataSourceProperties.PREFIX) +public record PgVectorDataSourceProperties( + boolean enabled, + String host, + String user, + String password, + Integer port, + String database +) { + static final String PREFIX = "langchain4j.pgvector.datasource"; + + /** + * Provide a default constructor that sets the default value of enabled to false. + */ + public PgVectorDataSourceProperties() { + this(false, null, null, null, null, null); + } +} diff --git a/langchain4j-pgvector-spring-boot-starter/src/main/java/dev/langchain4j/store/embedding/pgvector/spring/PgVectorEmbeddingStoreAutoConfiguration.java b/langchain4j-pgvector-spring-boot-starter/src/main/java/dev/langchain4j/store/embedding/pgvector/spring/PgVectorEmbeddingStoreAutoConfiguration.java new file mode 100644 index 00000000..2e492f5c --- /dev/null +++ b/langchain4j-pgvector-spring-boot-starter/src/main/java/dev/langchain4j/store/embedding/pgvector/spring/PgVectorEmbeddingStoreAutoConfiguration.java @@ -0,0 +1,133 @@ +package dev.langchain4j.store.embedding.pgvector.spring; + +import dev.langchain4j.model.embedding.EmbeddingModel; +import dev.langchain4j.store.embedding.pgvector.PgVectorEmbeddingStore; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.beans.factory.ObjectProvider; +import org.springframework.boot.autoconfigure.AutoConfiguration; +import org.springframework.boot.autoconfigure.condition.ConditionalOnBean; +import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean; +import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty; +import org.springframework.boot.context.properties.EnableConfigurationProperties; +import org.springframework.context.ApplicationContext; +import org.springframework.context.annotation.Bean; +import org.springframework.lang.Nullable; + +import javax.sql.DataSource; +import java.sql.Connection; +import java.sql.DatabaseMetaData; +import java.sql.SQLException; +import java.util.Map; +import java.util.Optional; + +import static dev.langchain4j.internal.ValidationUtils.*; +import static dev.langchain4j.store.embedding.pgvector.spring.PgVectorEmbeddingStoreProperties.*; +import static org.springframework.util.StringUtils.startsWithIgnoreCase; + +@AutoConfiguration +@EnableConfigurationProperties({PgVectorEmbeddingStoreProperties.class, PgVectorDataSourceProperties.class}) +@ConditionalOnProperty(prefix = PREFIX, name = "enabled", havingValue = "true", matchIfMissing = true) +public class PgVectorEmbeddingStoreAutoConfiguration { + + private static final Logger log = LoggerFactory.getLogger(PgVectorEmbeddingStoreAutoConfiguration.class); + + private final ApplicationContext applicationContext; + + public PgVectorEmbeddingStoreAutoConfiguration(ApplicationContext applicationContext) { + this.applicationContext = applicationContext; + } + + @Bean + @ConditionalOnMissingBean + @ConditionalOnBean(DataSource.class) + @ConditionalOnProperty(prefix = PgVectorDataSourceProperties.PREFIX, name = "enabled", havingValue = "false") + public PgVectorEmbeddingStore pgVectorEmbeddingStoreWithExistingDataSource(ObjectProvider dataSources, PgVectorEmbeddingStoreProperties properties, + @Nullable EmbeddingModel embeddingModel) { + + // The PostgreSQL data source is selected based on the configured dataSourceBeanName or automatically. + DataSource dataSource = dataSources.stream() + .filter(ds -> { + // Preferentially matches the configured dataSourceBeanName. + String beanName = properties.getDataSourceBeanName(); + if (beanName != null && !beanName.isEmpty()) { + String actualBeanName = getBeanNameForDataSource(ds); + return beanName.equals(actualBeanName); + } + return false; + }) + .findFirst() + // If no dataSourceBeanName is specified, the first PostgreSQL data source is selected. + .orElseGet(() -> dataSources.stream() + .filter(this::isPostgresqlDataSource) + .findFirst() + .orElseThrow(() -> new IllegalStateException("No suitable PostgreSQL DataSource found in the application context. " + + "Please configure a valid PostgreSQL DataSource."))); + + log.info("Using DataSource bean: {}", dataSource.getClass().getSimpleName()); + + // Check if the context's data source is a Postgres datasource + ensureTrue(isPostgresqlDataSource(dataSource), "The DataSource in Spring Context is not a Postgres datasource, you need to manually specify the Postgres datasource configuration via 'langchain4j.pgvector.datasource'."); + + Integer dimension = Optional.ofNullable(properties.getDimension()).orElseGet(() -> embeddingModel == null ? null : embeddingModel.dimension()); + + return PgVectorEmbeddingStore.datasourceBuilder() + .datasource(dataSource) + .table(properties.getTable()) + .createTable(properties.getCreateTable()) + .dimension(dimension) + .useIndex(properties.getUseIndex()) + .indexListSize(properties.getIndexListSize()) + .build(); + } + + @Bean + @ConditionalOnMissingBean + @ConditionalOnProperty(prefix = PgVectorDataSourceProperties.PREFIX, name = "enabled", havingValue = "true") + public PgVectorEmbeddingStore pgVectorEmbeddingStoreWithCustomDataSource(PgVectorEmbeddingStoreProperties properties, PgVectorDataSourceProperties dataSourceProperties, + @Nullable EmbeddingModel embeddingModel) { + Integer dimension = Optional.ofNullable(properties.getDimension()).orElseGet(() -> embeddingModel == null ? null : embeddingModel.dimension()); + + return PgVectorEmbeddingStore.builder() + .host(dataSourceProperties.host()) + .port(dataSourceProperties.port()) + .user(dataSourceProperties.user()) + .password(dataSourceProperties.password()) + .database(dataSourceProperties.database()) + .table(properties.getTable()) + .createTable(properties.getCreateTable()) + .dimension(dimension) + .useIndex(properties.getUseIndex()) + .indexListSize(properties.getIndexListSize()) + .build(); + } + + /** + * Check if the datasource is postgresql`. + * @param dataSource instance of {@link DataSource}. + * @return true means it is a postgresql data source, otherwise it is not. + */ + private boolean isPostgresqlDataSource(DataSource dataSource) { + try (Connection connection = dataSource.getConnection()) { + DatabaseMetaData metaData = connection.getMetaData(); + return startsWithIgnoreCase(metaData.getURL(), "jdbc:postgresql"); + } catch (SQLException e) { + log.warn("Exception checking datasource driver type during PgVector auto-configuration ."); + return false; + } + } + + /** + * Get the BeanName of the DataSource instance from the ApplicationContext. + * @param dataSource Target DataSource instance. + * @return bean name of target DataSource . + */ + private String getBeanNameForDataSource(DataSource dataSource) { + // Iterate through all DataSource beans to find the bean name that matches the current instance + return applicationContext.getBeansOfType(DataSource.class).entrySet().stream() + .filter(entry -> entry.getValue().equals(dataSource)) + .map(Map.Entry::getKey) + .findFirst() + .orElse(null); + } +} diff --git a/langchain4j-pgvector-spring-boot-starter/src/main/java/dev/langchain4j/store/embedding/pgvector/spring/PgVectorEmbeddingStoreProperties.java b/langchain4j-pgvector-spring-boot-starter/src/main/java/dev/langchain4j/store/embedding/pgvector/spring/PgVectorEmbeddingStoreProperties.java new file mode 100644 index 00000000..3febd8d7 --- /dev/null +++ b/langchain4j-pgvector-spring-boot-starter/src/main/java/dev/langchain4j/store/embedding/pgvector/spring/PgVectorEmbeddingStoreProperties.java @@ -0,0 +1,85 @@ +package dev.langchain4j.store.embedding.pgvector.spring; + +import org.springframework.boot.context.properties.ConfigurationProperties; + +@ConfigurationProperties(prefix = PgVectorEmbeddingStoreProperties.PREFIX) +public class PgVectorEmbeddingStoreProperties { + + static final String PREFIX = "langchain4j.pgvector"; + + /** + * The pgvector database table. + */ + private String table; + + /** + * The vector dimension. + */ + private Integer dimension; + + /** + * Should create table automatically, default value is false. + */ + private Boolean createTable; + + /** + * Should use IVFFlat index. + */ + private Boolean useIndex; + + /** + * The IVFFlat number of lists. + */ + private Integer indexListSize; + + private String dataSourceBeanName; + + + public String getTable() { + return table; + } + + public void setTable(String table) { + this.table = table; + } + + public Integer getDimension() { + return dimension; + } + + public void setDimension(Integer dimension) { + this.dimension = dimension; + } + + public Boolean getCreateTable() { + return createTable; + } + + public void setCreateTable(Boolean createTable) { + this.createTable = createTable; + } + + public Boolean getUseIndex() { + return useIndex; + } + + public void setUseIndex(Boolean useIndex) { + this.useIndex = useIndex; + } + + public Integer getIndexListSize() { + return indexListSize; + } + + public void setIndexListSize(Integer indexListSize) { + this.indexListSize = indexListSize; + } + + public String getDataSourceBeanName() { + return dataSourceBeanName; + } + + public void setDataSourceBeanName(String dataSourceBeanName) { + this.dataSourceBeanName = dataSourceBeanName; + } +} diff --git a/langchain4j-pgvector-spring-boot-starter/src/main/resources/META-INF/spring/org.springframework.boot.autoconfigure.AutoConfiguration.imports b/langchain4j-pgvector-spring-boot-starter/src/main/resources/META-INF/spring/org.springframework.boot.autoconfigure.AutoConfiguration.imports new file mode 100644 index 00000000..31b427d9 --- /dev/null +++ b/langchain4j-pgvector-spring-boot-starter/src/main/resources/META-INF/spring/org.springframework.boot.autoconfigure.AutoConfiguration.imports @@ -0,0 +1 @@ +dev.langchain4j.store.embedding.pgvector.spring.PgVectorEmbeddingStoreAutoConfiguration \ No newline at end of file diff --git a/langchain4j-pgvector-spring-boot-starter/src/test/java/dev/langchain4j/store/embedding/pgvector/spring/PgVectorEmbeddingStoreAutoConfigurationForDataSourceIT.java b/langchain4j-pgvector-spring-boot-starter/src/test/java/dev/langchain4j/store/embedding/pgvector/spring/PgVectorEmbeddingStoreAutoConfigurationForDataSourceIT.java new file mode 100644 index 00000000..30e58a6c --- /dev/null +++ b/langchain4j-pgvector-spring-boot-starter/src/test/java/dev/langchain4j/store/embedding/pgvector/spring/PgVectorEmbeddingStoreAutoConfigurationForDataSourceIT.java @@ -0,0 +1,209 @@ +package dev.langchain4j.store.embedding.pgvector.spring; + +import dev.langchain4j.store.embedding.pgvector.PgVectorEmbeddingStore; +import org.junit.jupiter.api.Test; +import org.springframework.beans.factory.BeanCreationException; +import org.springframework.boot.autoconfigure.AutoConfigurations; +import org.springframework.boot.test.context.runner.ApplicationContextRunner; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; + +import javax.sql.DataSource; + +import java.lang.reflect.Field; +import java.sql.Connection; +import java.sql.DatabaseMetaData; +import java.sql.Statement; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.Mockito.*; + +class PgVectorEmbeddingStoreAutoConfigurationForDataSourceIT { + + private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() + .withConfiguration(AutoConfigurations.of(PgVectorEmbeddingStoreAutoConfiguration.class)) + .withPropertyValues( + "langchain4j.pgvector.enabled=true", + "langchain4j.pgvector.datasource.enabled=false", + "langchain4j.pgvector.datasource.host=localhost", + "langchain4j.pgvector.datasource.port=5432", + "langchain4j.pgvector.datasource.user=testuser", + "langchain4j.pgvector.datasource.password=testpassword", + "langchain4j.pgvector.datasource.database=testdb", + "langchain4j.pgvector.table=embedding_table", + "langchain4j.pgvector.create-table=true", + "langchain4j.pgvector.dimension=768", + "langchain4j.pgvector.use-index=true", + "langchain4j.pgvector.index-list-size=10" + ); + + @Test + void testAutoConfigurationWithExistingDataSource() { + contextRunner + .withUserConfiguration(ExistingDataSourceConfig.class) + .run(context -> { + assertThat(context).hasSingleBean(PgVectorEmbeddingStore.class); + + PgVectorEmbeddingStore store = context.getBean(PgVectorEmbeddingStore.class); + assertThat(store).isNotNull(); + + DataSource dataSource = context.getBean(DataSource.class); + assertThat(dataSource).isNotNull(); + }); + } + + @Test + void testAutoConfigurationWithMultipleDataSourcesOfConfiguredTargetDataSourceBeanName() { + contextRunner + .withUserConfiguration(MultipleDataSourceConfig.class) + .withPropertyValues("langchain4j.pgvector.datasource-bean-name=secondaryDataSource") + .run(context -> { + // Verify that the PgVectorEmbeddingStore is correctly registered. + assertThat(context).hasSingleBean(PgVectorEmbeddingStore.class); + + // Get PgVectorEmbeddingStore instance + PgVectorEmbeddingStore store = context.getBean(PgVectorEmbeddingStore.class); + assertThat(store).isNotNull(); + + // Get DataSource instance + DataSource secondaryDataSource = context.getBean("secondaryDataSource", DataSource.class); + assertThat(secondaryDataSource).isNotNull(); + + // Get the DataSource of the PgVectorEmbeddingStore using reflection. + DataSource storeDataSource = getDataSourceFromStore(store); + + // Verify that the DataSource is consistent + assertThat(storeDataSource).isSameAs(secondaryDataSource); + }); + } + + @Test + void testAutoConfigurationWithMultipleDataSourcesOfNonConfiguredTargetDataSourceBeanName() { + contextRunner + .withUserConfiguration(MultipleDataSourceConfig.class) + .run(context -> { + // Verify that the PgVectorEmbeddingStore is correctly registered. + assertThat(context).hasSingleBean(PgVectorEmbeddingStore.class); + + // Get PgVectorEmbeddingStore instance + PgVectorEmbeddingStore store = context.getBean(PgVectorEmbeddingStore.class); + assertThat(store).isNotNull(); + + // Get DataSource instance + DataSource primaryDataSource = context.getBean("primaryDataSource", DataSource.class); + assertThat(primaryDataSource).isNotNull(); + + // Get the DataSource of the PgVectorEmbeddingStore using reflection. + DataSource storeDataSource = getDataSourceFromStore(store); + + // Verify that the DataSource is consistent + assertThat(storeDataSource).isSameAs(primaryDataSource); + }); + } + + private DataSource getDataSourceFromStore(PgVectorEmbeddingStore store) { + try { + // Let's assume that PgVectorEmbeddingStore has a field named “datasource” inside it. + Field dataSourceField = PgVectorEmbeddingStore.class.getDeclaredField("datasource"); + dataSourceField.setAccessible(true); + return (DataSource) dataSourceField.get(store); + } catch (NoSuchFieldException | IllegalAccessException e) { + throw new RuntimeException("Failed to access DataSource field from PgVectorEmbeddingStore", e); + } + } + + @Test + void testAutoConfigurationWithoutPostgresDataSource() { + contextRunner + .withUserConfiguration(NonPostgresDataSourceConfig.class) + .run(context -> { + // Verification context startup failure + Throwable startupFailure = context.getStartupFailure(); + assertThat(startupFailure).isNotNull(); // Make sure there are startup failure exceptions + assertThat(startupFailure) + .isInstanceOf(BeanCreationException.class) // Validating Exception Types + .hasRootCauseInstanceOf(IllegalStateException.class) // Verification of root cause type + .hasMessageContaining("No suitable PostgreSQL DataSource found in the application context"); + }); + } + + @Test + void testAutoConfigurationDisabled() { + contextRunner + .withPropertyValues("langchain4j.pgvector.enabled=false") + .run(context -> assertThat(context).doesNotHaveBean(PgVectorEmbeddingStore.class)); + } + + + private static DataSource mockPostgreDataSource() throws Exception { + // Mock DataSource + DataSource mockDataSource = mock(DataSource.class); + + // Mock Connection + Connection mockConnection = mock(Connection.class); + when(mockDataSource.getConnection()).thenReturn(mockConnection); + + // Mock DatabaseMetaData + DatabaseMetaData mockMetaData = mock(DatabaseMetaData.class); + when(mockConnection.getMetaData()).thenReturn(mockMetaData); + when(mockMetaData.getURL()).thenReturn("jdbc:postgresql://localhost:5432/testdb"); + + // Mock PGConnection (PostgreSQL-specific connection) + org.postgresql.PGConnection mockPGConnection = mock(org.postgresql.PGConnection.class); + when(mockConnection.unwrap(org.postgresql.PGConnection.class)).thenReturn(mockPGConnection); + + // Mock PGConnection's addDataType method + doNothing().when(mockPGConnection).addDataType(anyString(), any(Class.class)); + + // Mock Statement + Statement mockStatement = mock(Statement.class); + when(mockConnection.createStatement()).thenReturn(mockStatement); + + // Mock SQL Execution (e.g., table creation or updates) + when(mockStatement.executeUpdate(anyString())).thenReturn(1); + + return mockDataSource; + } + + @Configuration + static class ExistingDataSourceConfig { + + @Bean + public DataSource dataSource() throws Exception { + return mockPostgreDataSource(); + } + } + + @Configuration + static class MultipleDataSourceConfig { + + @Bean + public DataSource primaryDataSource() throws Exception { + return mockPostgreDataSource(); + } + + @Bean + public DataSource secondaryDataSource() throws Exception { + return mockPostgreDataSource(); + } + } + + @Configuration + static class NonPostgresDataSourceConfig { + + @Bean + public DataSource dataSource() throws Exception { + // Mock a non-PostgreSQL DataSource + DataSource mockDataSource = mock(DataSource.class); + Connection mockConnection = mock(Connection.class); + DatabaseMetaData mockMetaData = mock(DatabaseMetaData.class); + + when(mockDataSource.getConnection()).thenReturn(mockConnection); + when(mockConnection.getMetaData()).thenReturn(mockMetaData); + when(mockMetaData.getURL()).thenReturn("jdbc:mysql://localhost:3306/testdb"); + + return mockDataSource; + } + } +} diff --git a/langchain4j-pgvector-spring-boot-starter/src/test/java/dev/langchain4j/store/embedding/pgvector/spring/PgVectorEmbeddingStoreAutoConfigurationIT.java b/langchain4j-pgvector-spring-boot-starter/src/test/java/dev/langchain4j/store/embedding/pgvector/spring/PgVectorEmbeddingStoreAutoConfigurationIT.java new file mode 100644 index 00000000..f23d3305 --- /dev/null +++ b/langchain4j-pgvector-spring-boot-starter/src/test/java/dev/langchain4j/store/embedding/pgvector/spring/PgVectorEmbeddingStoreAutoConfigurationIT.java @@ -0,0 +1,75 @@ +package dev.langchain4j.store.embedding.pgvector.spring; + +import dev.langchain4j.data.segment.TextSegment; +import dev.langchain4j.model.embedding.onnx.allminilml6v2q.AllMiniLmL6V2QuantizedEmbeddingModel; +import dev.langchain4j.store.embedding.EmbeddingStore; +import dev.langchain4j.store.embedding.pgvector.PgVectorEmbeddingStore; +import dev.langchain4j.store.embedding.spring.EmbeddingStoreAutoConfigurationIT; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.BeforeEach; +import org.springframework.boot.autoconfigure.AutoConfigurations; +import org.springframework.boot.test.context.runner.ApplicationContextRunner; +import org.testcontainers.containers.PostgreSQLContainer; + +class PgVectorEmbeddingStoreAutoConfigurationIT extends EmbeddingStoreAutoConfigurationIT { + + static PostgreSQLContainer pgVector = new PostgreSQLContainer<>("pgvector/pgvector:pg16"); + static final String DEFAULT_TABLE = "test_langchain4j_table"; + + @BeforeAll + static void beforeAll() { + pgVector.start(); + } + + @AfterAll + static void afterAll() { + pgVector.stop(); + } + + @BeforeEach + void beforeEach() { + ApplicationContextRunner contextRunner = new ApplicationContextRunner() + .withConfiguration(AutoConfigurations.of(autoConfigurationClass())); + + contextRunner + .withBean(AllMiniLmL6V2QuantizedEmbeddingModel.class) + .withPropertyValues(properties()) + .run(context -> { + PgVectorEmbeddingStore embeddingStore = context.getBean(PgVectorEmbeddingStore.class); + embeddingStore.removeAll(); + }); + } + + @Override + protected Class autoConfigurationClass() { + return PgVectorEmbeddingStoreAutoConfiguration.class; + } + + @Override + protected Class> embeddingStoreClass() { + return PgVectorEmbeddingStore.class; + } + + @Override + protected String[] properties() { + return new String[]{ + "langchain4j.pgvector.datasource.enabled=true", + "langchain4j.pgvector.datasource.host=" + pgVector.getHost(), + "langchain4j.pgvector.datasource.port=" + pgVector.getMappedPort(5432), + "langchain4j.pgvector.datasource.user=" + pgVector.getUsername(), + "langchain4j.pgvector.datasource.password=" + pgVector.getPassword(), + "langchain4j.pgvector.datasource.database=" + pgVector.getDatabaseName(), + "langchain4j.pgvector.table=" + DEFAULT_TABLE, + "langchain4j.pgvector.create-table=true", + "langchain4j.pgvector.use-index=true", + "langchain4j.pgvector.index-list-size=100", + "langchain4j.pgvector.dimension=384" + }; + } + + @Override + protected String dimensionPropertyKey() { + return "langchain4j.pgvector.dimension"; + } +} diff --git a/pom.xml b/pom.xml index 8526470f..7949e24b 100644 --- a/pom.xml +++ b/pom.xml @@ -29,6 +29,7 @@ langchain4j-elasticsearch-spring-boot-starter langchain4j-redis-spring-boot-starter langchain4j-milvus-spring-boot-starter + langchain4j-pgvector-spring-boot-starter langchain4j-reactor