diff --git a/langchain4j-pgvector-spring-boot-starter/src/main/java/dev/langchain4j/store/embedding/pgvector/spring/PgVectorDataSourceProperties.java b/langchain4j-pgvector-spring-boot-starter/src/main/java/dev/langchain4j/store/embedding/pgvector/spring/PgVectorDataSourceProperties.java index 7341e707..08ee0126 100644 --- a/langchain4j-pgvector-spring-boot-starter/src/main/java/dev/langchain4j/store/embedding/pgvector/spring/PgVectorDataSourceProperties.java +++ b/langchain4j-pgvector-spring-boot-starter/src/main/java/dev/langchain4j/store/embedding/pgvector/spring/PgVectorDataSourceProperties.java @@ -3,85 +3,20 @@ import org.springframework.boot.context.properties.ConfigurationProperties; @ConfigurationProperties(prefix = PgVectorDataSourceProperties.PREFIX) -public class PgVectorDataSourceProperties { - +public record PgVectorDataSourceProperties( + boolean enabled, + String host, + String user, + String password, + Integer port, + String database +) { static final String PREFIX = "langchain4j.pgvector.datasource"; /** - * Enable postgres datasource configuration, default value false. - */ - private boolean enabled = false; - - /** - * The pgvector database host. - */ - private String host; - - /** - * The pgvector database user. - */ - private String user; - - /** - * The pgvector database password. + * Provide a default constructor that sets the default value of enabled to false. */ - private String password; - - /** - * The pgvector database port. - */ - private Integer port; - - /** - * The pgvector database name. - */ - private String database; - - public boolean isEnabled() { - return enabled; - } - - public void setEnabled(boolean enabled) { - this.enabled = enabled; - } - - public String getHost() { - return host; - } - - public void setHost(String host) { - this.host = host; - } - - public String getUser() { - return user; - } - - public void setUser(String user) { - this.user = user; - } - - public String getPassword() { - return password; - } - - public void setPassword(String password) { - this.password = password; - } - - public Integer getPort() { - return port; - } - - public void setPort(Integer port) { - this.port = port; - } - - public String getDatabase() { - return database; - } - - public void setDatabase(String database) { - this.database = database; + public PgVectorDataSourceProperties() { + this(false, null, null, null, null, null); } } diff --git a/langchain4j-pgvector-spring-boot-starter/src/main/java/dev/langchain4j/store/embedding/pgvector/spring/PgVectorEmbeddingStoreAutoConfiguration.java b/langchain4j-pgvector-spring-boot-starter/src/main/java/dev/langchain4j/store/embedding/pgvector/spring/PgVectorEmbeddingStoreAutoConfiguration.java index 418f7be3..34f7ada9 100644 --- a/langchain4j-pgvector-spring-boot-starter/src/main/java/dev/langchain4j/store/embedding/pgvector/spring/PgVectorEmbeddingStoreAutoConfiguration.java +++ b/langchain4j-pgvector-spring-boot-starter/src/main/java/dev/langchain4j/store/embedding/pgvector/spring/PgVectorEmbeddingStoreAutoConfiguration.java @@ -4,11 +4,13 @@ import dev.langchain4j.store.embedding.pgvector.PgVectorEmbeddingStore; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.springframework.beans.factory.ObjectProvider; import org.springframework.boot.autoconfigure.AutoConfiguration; import org.springframework.boot.autoconfigure.condition.ConditionalOnBean; import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean; import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty; import org.springframework.boot.context.properties.EnableConfigurationProperties; +import org.springframework.context.ApplicationContext; import org.springframework.context.annotation.Bean; import org.springframework.lang.Nullable; @@ -16,6 +18,7 @@ import java.sql.Connection; import java.sql.DatabaseMetaData; import java.sql.SQLException; +import java.util.Map; import java.util.Optional; import static dev.langchain4j.internal.ValidationUtils.*; @@ -29,12 +32,40 @@ public class PgVectorEmbeddingStoreAutoConfiguration { private static final Logger log = LoggerFactory.getLogger(PgVectorEmbeddingStoreAutoConfiguration.class); + private final ApplicationContext applicationContext; + + public PgVectorEmbeddingStoreAutoConfiguration(ApplicationContext applicationContext) { + this.applicationContext = applicationContext; + } + @Bean @ConditionalOnMissingBean @ConditionalOnBean(DataSource.class) @ConditionalOnProperty(prefix = PgVectorDataSourceProperties.PREFIX, name = "enabled", havingValue = "false") - public PgVectorEmbeddingStore pgVectorEmbeddingStoreWithExistingDataSource(DataSource dataSource, PgVectorEmbeddingStoreProperties properties, - @Nullable EmbeddingModel embeddingModel) { + public PgVectorEmbeddingStore pgVectorEmbeddingStoreWithExistingDataSource(ObjectProvider dataSources, PgVectorEmbeddingStoreProperties properties, + @Nullable EmbeddingModel embeddingModel) { + + // The PostgreSQL data source is selected based on the configured dataSourceBeanName or automatically. + DataSource dataSource = dataSources.stream() + .filter(ds -> { + // Preferentially matches the configured dataSourceBeanName. + String beanName = properties.getDataSourceBeanName(); + if (beanName != null && !beanName.isEmpty()) { + String actualBeanName = getBeanNameForDataSource(ds); + return beanName.equals(actualBeanName); + } + return false; + }) + .findFirst() + // If no dataSourceBeanName is specified, the first PostgreSQL data source is selected. + .orElseGet(() -> dataSources.stream() + .filter(this::isPostgresqlDataSource) + .findFirst() + .orElseThrow(() -> new IllegalStateException("No suitable PostgreSQL DataSource found in the application context. " + + "Please configure a valid PostgreSQL DataSource."))); + + log.info("Using DataSource bean: {}", dataSource.getClass().getSimpleName()); + // Check if the context's data source is a Postgres datasource ensureTrue(isPostgresqlDataSource(dataSource), "The DataSource in Spring Context is not a Postgres datasource, you need to manually specify the Postgres datasource configuration via 'langchain4j.pgvector.datasource'."); @@ -58,11 +89,11 @@ public PgVectorEmbeddingStore pgVectorEmbeddingStoreWithCustomDataSource(PgVecto Integer dimension = Optional.ofNullable(properties.getDimension()).orElseGet(() -> embeddingModel == null ? null : embeddingModel.dimension()); return PgVectorEmbeddingStore.builder() - .host(dataSourceProperties.getHost()) - .port(dataSourceProperties.getPort()) - .user(dataSourceProperties.getUser()) - .password(dataSourceProperties.getPassword()) - .database(dataSourceProperties.getDatabase()) + .host(dataSourceProperties.host()) + .port(dataSourceProperties.port()) + .user(dataSourceProperties.user()) + .password(dataSourceProperties.password()) + .database(dataSourceProperties.database()) .table(properties.getTable()) .createTable(properties.getCreateTable()) .dimension(dimension) @@ -85,4 +116,18 @@ private boolean isPostgresqlDataSource(DataSource dataSource) { return false; } } + + /** + * Get the BeanName of the DataSource instance from the ApplicationContext. + * @param dataSource Target DataSource instance. + * @return bean name of target DataSource . + */ + private String getBeanNameForDataSource(DataSource dataSource) { + // 遍历所有 DataSource Bean,找到与当前实例匹配的 Bean 名称 + return applicationContext.getBeansOfType(DataSource.class).entrySet().stream() + .filter(entry -> entry.getValue().equals(dataSource)) + .map(Map.Entry::getKey) + .findFirst() + .orElse(null); + } } diff --git a/langchain4j-pgvector-spring-boot-starter/src/main/java/dev/langchain4j/store/embedding/pgvector/spring/PgVectorEmbeddingStoreProperties.java b/langchain4j-pgvector-spring-boot-starter/src/main/java/dev/langchain4j/store/embedding/pgvector/spring/PgVectorEmbeddingStoreProperties.java index 63e98294..3febd8d7 100644 --- a/langchain4j-pgvector-spring-boot-starter/src/main/java/dev/langchain4j/store/embedding/pgvector/spring/PgVectorEmbeddingStoreProperties.java +++ b/langchain4j-pgvector-spring-boot-starter/src/main/java/dev/langchain4j/store/embedding/pgvector/spring/PgVectorEmbeddingStoreProperties.java @@ -32,6 +32,9 @@ public class PgVectorEmbeddingStoreProperties { */ private Integer indexListSize; + private String dataSourceBeanName; + + public String getTable() { return table; } @@ -71,4 +74,12 @@ public Integer getIndexListSize() { public void setIndexListSize(Integer indexListSize) { this.indexListSize = indexListSize; } + + public String getDataSourceBeanName() { + return dataSourceBeanName; + } + + public void setDataSourceBeanName(String dataSourceBeanName) { + this.dataSourceBeanName = dataSourceBeanName; + } } diff --git a/langchain4j-pgvector-spring-boot-starter/src/test/java/dev/langchain4j/store/embedding/pgvector/spring/PgVectorEmbeddingStoreAutoConfigurationForDataSourceIT.java b/langchain4j-pgvector-spring-boot-starter/src/test/java/dev/langchain4j/store/embedding/pgvector/spring/PgVectorEmbeddingStoreAutoConfigurationForDataSourceIT.java new file mode 100644 index 00000000..76232487 --- /dev/null +++ b/langchain4j-pgvector-spring-boot-starter/src/test/java/dev/langchain4j/store/embedding/pgvector/spring/PgVectorEmbeddingStoreAutoConfigurationForDataSourceIT.java @@ -0,0 +1,209 @@ +package dev.langchain4j.store.embedding.pgvector.spring; + +import dev.langchain4j.store.embedding.pgvector.PgVectorEmbeddingStore; +import org.junit.jupiter.api.Test; +import org.springframework.beans.factory.BeanCreationException; +import org.springframework.boot.autoconfigure.AutoConfigurations; +import org.springframework.boot.test.context.runner.ApplicationContextRunner; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; + +import javax.sql.DataSource; + +import java.lang.reflect.Field; +import java.sql.Connection; +import java.sql.DatabaseMetaData; +import java.sql.Statement; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.Mockito.*; + +class PgVectorEmbeddingStoreAutoConfigurationForDataSourceIT { + + private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() + .withConfiguration(AutoConfigurations.of(PgVectorEmbeddingStoreAutoConfiguration.class)) + .withPropertyValues( + "langchain4j.pgvector.enabled=true", + "langchain4j.pgvector.datasource.enabled=false", + "langchain4j.pgvector.datasource.host=localhost", + "langchain4j.pgvector.datasource.port=5432", + "langchain4j.pgvector.datasource.user=testuser", + "langchain4j.pgvector.datasource.password=testpassword", + "langchain4j.pgvector.datasource.database=testdb", + "langchain4j.pgvector.table=embedding_table", + "langchain4j.pgvector.create-table=true", + "langchain4j.pgvector.dimension=768", + "langchain4j.pgvector.use-index=true", + "langchain4j.pgvector.index-list-size=10" + ); + + @Test + void testAutoConfigurationWithExistingDataSource() { + contextRunner + .withUserConfiguration(ExistingDataSourceConfig.class) + .run(context -> { + assertThat(context).hasSingleBean(PgVectorEmbeddingStore.class); + + PgVectorEmbeddingStore store = context.getBean(PgVectorEmbeddingStore.class); + assertThat(store).isNotNull(); + + DataSource dataSource = context.getBean(DataSource.class); + assertThat(dataSource).isNotNull(); + }); + } + + @Test + void testAutoConfigurationWithMultipleDataSourcesOfConfiguredTargetDataSourceBeanName() { + contextRunner + .withUserConfiguration(MultipleDataSourceConfig.class) + .withPropertyValues("langchain4j.pgvector.datasource-bean-name=secondaryDataSource") + .run(context -> { + // Verify that the PgVectorEmbeddingStore is correctly registered. + assertThat(context).hasSingleBean(PgVectorEmbeddingStore.class); + + // Get PgVectorEmbeddingStore instance + PgVectorEmbeddingStore store = context.getBean(PgVectorEmbeddingStore.class); + assertThat(store).isNotNull(); + + // Get DataSource instance + DataSource secondaryDataSource = context.getBean("secondaryDataSource", DataSource.class); + assertThat(secondaryDataSource).isNotNull(); + + // Get the DataSource of the PgVectorEmbeddingStore using reflection. + DataSource storeDataSource = getDataSourceFromStore(store); + + // Verify that the DataSource is consistent + assertThat(storeDataSource).isSameAs(secondaryDataSource); + }); + } + + @Test + void testAutoConfigurationWithMultipleDataSourcesOfNonConfiguredTargetDataSourceBeanName() { + contextRunner + .withUserConfiguration(MultipleDataSourceConfig.class) + .run(context -> { + // Verify that the PgVectorEmbeddingStore is correctly registered. + assertThat(context).hasSingleBean(PgVectorEmbeddingStore.class); + + // Get PgVectorEmbeddingStore instance + PgVectorEmbeddingStore store = context.getBean(PgVectorEmbeddingStore.class); + assertThat(store).isNotNull(); + + // Get DataSource instance + DataSource primaryDataSource = context.getBean("primaryDataSource", DataSource.class); + assertThat(primaryDataSource).isNotNull(); + + // Get the DataSource of the PgVectorEmbeddingStore using reflection. + DataSource storeDataSource = getDataSourceFromStore(store); + + // Verify that the DataSource is consistent + assertThat(storeDataSource).isSameAs(primaryDataSource); + }); + } + + private DataSource getDataSourceFromStore(PgVectorEmbeddingStore store) { + try { + // Let's assume that PgVectorEmbeddingStore has a field named “datasource” inside it. + Field dataSourceField = PgVectorEmbeddingStore.class.getDeclaredField("datasource"); + dataSourceField.setAccessible(true); + return (DataSource) dataSourceField.get(store); + } catch (NoSuchFieldException | IllegalAccessException e) { + throw new RuntimeException("Failed to access DataSource field from PgVectorEmbeddingStore", e); + } + } + + @Test + void testAutoConfigurationWithoutPostgresDataSource() { + contextRunner + .withUserConfiguration(NonPostgresDataSourceConfig.class) + .run(context -> { + // Verification context startup failure + Throwable startupFailure = context.getStartupFailure(); + assertThat(startupFailure).isNotNull(); // Make sure there are startup failure exceptions + assertThat(startupFailure) + .isInstanceOf(BeanCreationException.class) // Validating Exception Types + .hasRootCauseInstanceOf(IllegalStateException.class) // Verification of root cause type + .hasMessageContaining("No suitable PostgreSQL DataSource found in the application context"); // 验证异常信息 + }); + } + + @Test + void testAutoConfigurationDisabled() { + contextRunner + .withPropertyValues("langchain4j.pgvector.enabled=false") + .run(context -> assertThat(context).doesNotHaveBean(PgVectorEmbeddingStore.class)); + } + + + private static DataSource mockPostgreDataSource() throws Exception { + // Mock DataSource + DataSource mockDataSource = mock(DataSource.class); + + // Mock Connection + Connection mockConnection = mock(Connection.class); + when(mockDataSource.getConnection()).thenReturn(mockConnection); + + // Mock DatabaseMetaData + DatabaseMetaData mockMetaData = mock(DatabaseMetaData.class); + when(mockConnection.getMetaData()).thenReturn(mockMetaData); + when(mockMetaData.getURL()).thenReturn("jdbc:postgresql://localhost:5432/testdb"); + + // Mock PGConnection (PostgreSQL-specific connection) + org.postgresql.PGConnection mockPGConnection = mock(org.postgresql.PGConnection.class); + when(mockConnection.unwrap(org.postgresql.PGConnection.class)).thenReturn(mockPGConnection); + + // Mock PGConnection's addDataType method + doNothing().when(mockPGConnection).addDataType(anyString(), any(Class.class)); + + // Mock Statement + Statement mockStatement = mock(Statement.class); + when(mockConnection.createStatement()).thenReturn(mockStatement); + + // Mock SQL Execution (e.g., table creation or updates) + when(mockStatement.executeUpdate(anyString())).thenReturn(1); + + return mockDataSource; + } + + @Configuration + static class ExistingDataSourceConfig { + + @Bean + public DataSource dataSource() throws Exception { + return mockPostgreDataSource(); + } + } + + @Configuration + static class MultipleDataSourceConfig { + + @Bean + public DataSource primaryDataSource() throws Exception { + return mockPostgreDataSource(); + } + + @Bean + public DataSource secondaryDataSource() throws Exception { + return mockPostgreDataSource(); + } + } + + @Configuration + static class NonPostgresDataSourceConfig { + + @Bean + public DataSource dataSource() throws Exception { + // Mock a non-PostgreSQL DataSource + DataSource mockDataSource = mock(DataSource.class); + Connection mockConnection = mock(Connection.class); + DatabaseMetaData mockMetaData = mock(DatabaseMetaData.class); + + when(mockDataSource.getConnection()).thenReturn(mockConnection); + when(mockConnection.getMetaData()).thenReturn(mockMetaData); + when(mockMetaData.getURL()).thenReturn("jdbc:mysql://localhost:3306/testdb"); + + return mockDataSource; + } + } +}