diff --git a/langchain4j-pgvector-spring-boot-starter/src/main/java/dev/langchain4j/store/embedding/pgvector/spring/PgVectorDataSourceProperties.java b/langchain4j-pgvector-spring-boot-starter/src/main/java/dev/langchain4j/store/embedding/pgvector/spring/PgVectorDataSourceProperties.java
index 7341e707..08ee0126 100644
--- a/langchain4j-pgvector-spring-boot-starter/src/main/java/dev/langchain4j/store/embedding/pgvector/spring/PgVectorDataSourceProperties.java
+++ b/langchain4j-pgvector-spring-boot-starter/src/main/java/dev/langchain4j/store/embedding/pgvector/spring/PgVectorDataSourceProperties.java
@@ -3,85 +3,20 @@
import org.springframework.boot.context.properties.ConfigurationProperties;
@ConfigurationProperties(prefix = PgVectorDataSourceProperties.PREFIX)
-public class PgVectorDataSourceProperties {
-
+public record PgVectorDataSourceProperties(
+ boolean enabled,
+ String host,
+ String user,
+ String password,
+ Integer port,
+ String database
+) {
static final String PREFIX = "langchain4j.pgvector.datasource";
/**
- * Enable postgres datasource configuration, default value false
.
- */
- private boolean enabled = false;
-
- /**
- * The pgvector database host.
- */
- private String host;
-
- /**
- * The pgvector database user.
- */
- private String user;
-
- /**
- * The pgvector database password.
+ * Provide a default constructor that sets the default value of enabled to false.
*/
- private String password;
-
- /**
- * The pgvector database port.
- */
- private Integer port;
-
- /**
- * The pgvector database name.
- */
- private String database;
-
- public boolean isEnabled() {
- return enabled;
- }
-
- public void setEnabled(boolean enabled) {
- this.enabled = enabled;
- }
-
- public String getHost() {
- return host;
- }
-
- public void setHost(String host) {
- this.host = host;
- }
-
- public String getUser() {
- return user;
- }
-
- public void setUser(String user) {
- this.user = user;
- }
-
- public String getPassword() {
- return password;
- }
-
- public void setPassword(String password) {
- this.password = password;
- }
-
- public Integer getPort() {
- return port;
- }
-
- public void setPort(Integer port) {
- this.port = port;
- }
-
- public String getDatabase() {
- return database;
- }
-
- public void setDatabase(String database) {
- this.database = database;
+ public PgVectorDataSourceProperties() {
+ this(false, null, null, null, null, null);
}
}
diff --git a/langchain4j-pgvector-spring-boot-starter/src/main/java/dev/langchain4j/store/embedding/pgvector/spring/PgVectorEmbeddingStoreAutoConfiguration.java b/langchain4j-pgvector-spring-boot-starter/src/main/java/dev/langchain4j/store/embedding/pgvector/spring/PgVectorEmbeddingStoreAutoConfiguration.java
index 418f7be3..34f7ada9 100644
--- a/langchain4j-pgvector-spring-boot-starter/src/main/java/dev/langchain4j/store/embedding/pgvector/spring/PgVectorEmbeddingStoreAutoConfiguration.java
+++ b/langchain4j-pgvector-spring-boot-starter/src/main/java/dev/langchain4j/store/embedding/pgvector/spring/PgVectorEmbeddingStoreAutoConfiguration.java
@@ -4,11 +4,13 @@
import dev.langchain4j.store.embedding.pgvector.PgVectorEmbeddingStore;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
+import org.springframework.beans.factory.ObjectProvider;
import org.springframework.boot.autoconfigure.AutoConfiguration;
import org.springframework.boot.autoconfigure.condition.ConditionalOnBean;
import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean;
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
import org.springframework.boot.context.properties.EnableConfigurationProperties;
+import org.springframework.context.ApplicationContext;
import org.springframework.context.annotation.Bean;
import org.springframework.lang.Nullable;
@@ -16,6 +18,7 @@
import java.sql.Connection;
import java.sql.DatabaseMetaData;
import java.sql.SQLException;
+import java.util.Map;
import java.util.Optional;
import static dev.langchain4j.internal.ValidationUtils.*;
@@ -29,12 +32,40 @@ public class PgVectorEmbeddingStoreAutoConfiguration {
private static final Logger log = LoggerFactory.getLogger(PgVectorEmbeddingStoreAutoConfiguration.class);
+ private final ApplicationContext applicationContext;
+
+ public PgVectorEmbeddingStoreAutoConfiguration(ApplicationContext applicationContext) {
+ this.applicationContext = applicationContext;
+ }
+
@Bean
@ConditionalOnMissingBean
@ConditionalOnBean(DataSource.class)
@ConditionalOnProperty(prefix = PgVectorDataSourceProperties.PREFIX, name = "enabled", havingValue = "false")
- public PgVectorEmbeddingStore pgVectorEmbeddingStoreWithExistingDataSource(DataSource dataSource, PgVectorEmbeddingStoreProperties properties,
- @Nullable EmbeddingModel embeddingModel) {
+ public PgVectorEmbeddingStore pgVectorEmbeddingStoreWithExistingDataSource(ObjectProvider dataSources, PgVectorEmbeddingStoreProperties properties,
+ @Nullable EmbeddingModel embeddingModel) {
+
+ // The PostgreSQL data source is selected based on the configured dataSourceBeanName or automatically.
+ DataSource dataSource = dataSources.stream()
+ .filter(ds -> {
+ // Preferentially matches the configured dataSourceBeanName.
+ String beanName = properties.getDataSourceBeanName();
+ if (beanName != null && !beanName.isEmpty()) {
+ String actualBeanName = getBeanNameForDataSource(ds);
+ return beanName.equals(actualBeanName);
+ }
+ return false;
+ })
+ .findFirst()
+ // If no dataSourceBeanName is specified, the first PostgreSQL data source is selected.
+ .orElseGet(() -> dataSources.stream()
+ .filter(this::isPostgresqlDataSource)
+ .findFirst()
+ .orElseThrow(() -> new IllegalStateException("No suitable PostgreSQL DataSource found in the application context. "
+ + "Please configure a valid PostgreSQL DataSource.")));
+
+ log.info("Using DataSource bean: {}", dataSource.getClass().getSimpleName());
+
// Check if the context's data source is a Postgres datasource
ensureTrue(isPostgresqlDataSource(dataSource), "The DataSource in Spring Context is not a Postgres datasource, you need to manually specify the Postgres datasource configuration via 'langchain4j.pgvector.datasource'.");
@@ -58,11 +89,11 @@ public PgVectorEmbeddingStore pgVectorEmbeddingStoreWithCustomDataSource(PgVecto
Integer dimension = Optional.ofNullable(properties.getDimension()).orElseGet(() -> embeddingModel == null ? null : embeddingModel.dimension());
return PgVectorEmbeddingStore.builder()
- .host(dataSourceProperties.getHost())
- .port(dataSourceProperties.getPort())
- .user(dataSourceProperties.getUser())
- .password(dataSourceProperties.getPassword())
- .database(dataSourceProperties.getDatabase())
+ .host(dataSourceProperties.host())
+ .port(dataSourceProperties.port())
+ .user(dataSourceProperties.user())
+ .password(dataSourceProperties.password())
+ .database(dataSourceProperties.database())
.table(properties.getTable())
.createTable(properties.getCreateTable())
.dimension(dimension)
@@ -85,4 +116,18 @@ private boolean isPostgresqlDataSource(DataSource dataSource) {
return false;
}
}
+
+ /**
+ * Get the BeanName of the DataSource instance from the ApplicationContext.
+ * @param dataSource Target DataSource instance.
+ * @return bean name of target DataSource .
+ */
+ private String getBeanNameForDataSource(DataSource dataSource) {
+ // 遍历所有 DataSource Bean,找到与当前实例匹配的 Bean 名称
+ return applicationContext.getBeansOfType(DataSource.class).entrySet().stream()
+ .filter(entry -> entry.getValue().equals(dataSource))
+ .map(Map.Entry::getKey)
+ .findFirst()
+ .orElse(null);
+ }
}
diff --git a/langchain4j-pgvector-spring-boot-starter/src/main/java/dev/langchain4j/store/embedding/pgvector/spring/PgVectorEmbeddingStoreProperties.java b/langchain4j-pgvector-spring-boot-starter/src/main/java/dev/langchain4j/store/embedding/pgvector/spring/PgVectorEmbeddingStoreProperties.java
index 63e98294..3febd8d7 100644
--- a/langchain4j-pgvector-spring-boot-starter/src/main/java/dev/langchain4j/store/embedding/pgvector/spring/PgVectorEmbeddingStoreProperties.java
+++ b/langchain4j-pgvector-spring-boot-starter/src/main/java/dev/langchain4j/store/embedding/pgvector/spring/PgVectorEmbeddingStoreProperties.java
@@ -32,6 +32,9 @@ public class PgVectorEmbeddingStoreProperties {
*/
private Integer indexListSize;
+ private String dataSourceBeanName;
+
+
public String getTable() {
return table;
}
@@ -71,4 +74,12 @@ public Integer getIndexListSize() {
public void setIndexListSize(Integer indexListSize) {
this.indexListSize = indexListSize;
}
+
+ public String getDataSourceBeanName() {
+ return dataSourceBeanName;
+ }
+
+ public void setDataSourceBeanName(String dataSourceBeanName) {
+ this.dataSourceBeanName = dataSourceBeanName;
+ }
}
diff --git a/langchain4j-pgvector-spring-boot-starter/src/test/java/dev/langchain4j/store/embedding/pgvector/spring/PgVectorEmbeddingStoreAutoConfigurationForDataSourceIT.java b/langchain4j-pgvector-spring-boot-starter/src/test/java/dev/langchain4j/store/embedding/pgvector/spring/PgVectorEmbeddingStoreAutoConfigurationForDataSourceIT.java
new file mode 100644
index 00000000..76232487
--- /dev/null
+++ b/langchain4j-pgvector-spring-boot-starter/src/test/java/dev/langchain4j/store/embedding/pgvector/spring/PgVectorEmbeddingStoreAutoConfigurationForDataSourceIT.java
@@ -0,0 +1,209 @@
+package dev.langchain4j.store.embedding.pgvector.spring;
+
+import dev.langchain4j.store.embedding.pgvector.PgVectorEmbeddingStore;
+import org.junit.jupiter.api.Test;
+import org.springframework.beans.factory.BeanCreationException;
+import org.springframework.boot.autoconfigure.AutoConfigurations;
+import org.springframework.boot.test.context.runner.ApplicationContextRunner;
+import org.springframework.context.annotation.Bean;
+import org.springframework.context.annotation.Configuration;
+
+import javax.sql.DataSource;
+
+import java.lang.reflect.Field;
+import java.sql.Connection;
+import java.sql.DatabaseMetaData;
+import java.sql.Statement;
+
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.mockito.ArgumentMatchers.anyString;
+import static org.mockito.Mockito.*;
+
+class PgVectorEmbeddingStoreAutoConfigurationForDataSourceIT {
+
+ private final ApplicationContextRunner contextRunner = new ApplicationContextRunner()
+ .withConfiguration(AutoConfigurations.of(PgVectorEmbeddingStoreAutoConfiguration.class))
+ .withPropertyValues(
+ "langchain4j.pgvector.enabled=true",
+ "langchain4j.pgvector.datasource.enabled=false",
+ "langchain4j.pgvector.datasource.host=localhost",
+ "langchain4j.pgvector.datasource.port=5432",
+ "langchain4j.pgvector.datasource.user=testuser",
+ "langchain4j.pgvector.datasource.password=testpassword",
+ "langchain4j.pgvector.datasource.database=testdb",
+ "langchain4j.pgvector.table=embedding_table",
+ "langchain4j.pgvector.create-table=true",
+ "langchain4j.pgvector.dimension=768",
+ "langchain4j.pgvector.use-index=true",
+ "langchain4j.pgvector.index-list-size=10"
+ );
+
+ @Test
+ void testAutoConfigurationWithExistingDataSource() {
+ contextRunner
+ .withUserConfiguration(ExistingDataSourceConfig.class)
+ .run(context -> {
+ assertThat(context).hasSingleBean(PgVectorEmbeddingStore.class);
+
+ PgVectorEmbeddingStore store = context.getBean(PgVectorEmbeddingStore.class);
+ assertThat(store).isNotNull();
+
+ DataSource dataSource = context.getBean(DataSource.class);
+ assertThat(dataSource).isNotNull();
+ });
+ }
+
+ @Test
+ void testAutoConfigurationWithMultipleDataSourcesOfConfiguredTargetDataSourceBeanName() {
+ contextRunner
+ .withUserConfiguration(MultipleDataSourceConfig.class)
+ .withPropertyValues("langchain4j.pgvector.datasource-bean-name=secondaryDataSource")
+ .run(context -> {
+ // Verify that the PgVectorEmbeddingStore is correctly registered.
+ assertThat(context).hasSingleBean(PgVectorEmbeddingStore.class);
+
+ // Get PgVectorEmbeddingStore instance
+ PgVectorEmbeddingStore store = context.getBean(PgVectorEmbeddingStore.class);
+ assertThat(store).isNotNull();
+
+ // Get DataSource instance
+ DataSource secondaryDataSource = context.getBean("secondaryDataSource", DataSource.class);
+ assertThat(secondaryDataSource).isNotNull();
+
+ // Get the DataSource of the PgVectorEmbeddingStore using reflection.
+ DataSource storeDataSource = getDataSourceFromStore(store);
+
+ // Verify that the DataSource is consistent
+ assertThat(storeDataSource).isSameAs(secondaryDataSource);
+ });
+ }
+
+ @Test
+ void testAutoConfigurationWithMultipleDataSourcesOfNonConfiguredTargetDataSourceBeanName() {
+ contextRunner
+ .withUserConfiguration(MultipleDataSourceConfig.class)
+ .run(context -> {
+ // Verify that the PgVectorEmbeddingStore is correctly registered.
+ assertThat(context).hasSingleBean(PgVectorEmbeddingStore.class);
+
+ // Get PgVectorEmbeddingStore instance
+ PgVectorEmbeddingStore store = context.getBean(PgVectorEmbeddingStore.class);
+ assertThat(store).isNotNull();
+
+ // Get DataSource instance
+ DataSource primaryDataSource = context.getBean("primaryDataSource", DataSource.class);
+ assertThat(primaryDataSource).isNotNull();
+
+ // Get the DataSource of the PgVectorEmbeddingStore using reflection.
+ DataSource storeDataSource = getDataSourceFromStore(store);
+
+ // Verify that the DataSource is consistent
+ assertThat(storeDataSource).isSameAs(primaryDataSource);
+ });
+ }
+
+ private DataSource getDataSourceFromStore(PgVectorEmbeddingStore store) {
+ try {
+ // Let's assume that PgVectorEmbeddingStore has a field named “datasource” inside it.
+ Field dataSourceField = PgVectorEmbeddingStore.class.getDeclaredField("datasource");
+ dataSourceField.setAccessible(true);
+ return (DataSource) dataSourceField.get(store);
+ } catch (NoSuchFieldException | IllegalAccessException e) {
+ throw new RuntimeException("Failed to access DataSource field from PgVectorEmbeddingStore", e);
+ }
+ }
+
+ @Test
+ void testAutoConfigurationWithoutPostgresDataSource() {
+ contextRunner
+ .withUserConfiguration(NonPostgresDataSourceConfig.class)
+ .run(context -> {
+ // Verification context startup failure
+ Throwable startupFailure = context.getStartupFailure();
+ assertThat(startupFailure).isNotNull(); // Make sure there are startup failure exceptions
+ assertThat(startupFailure)
+ .isInstanceOf(BeanCreationException.class) // Validating Exception Types
+ .hasRootCauseInstanceOf(IllegalStateException.class) // Verification of root cause type
+ .hasMessageContaining("No suitable PostgreSQL DataSource found in the application context"); // 验证异常信息
+ });
+ }
+
+ @Test
+ void testAutoConfigurationDisabled() {
+ contextRunner
+ .withPropertyValues("langchain4j.pgvector.enabled=false")
+ .run(context -> assertThat(context).doesNotHaveBean(PgVectorEmbeddingStore.class));
+ }
+
+
+ private static DataSource mockPostgreDataSource() throws Exception {
+ // Mock DataSource
+ DataSource mockDataSource = mock(DataSource.class);
+
+ // Mock Connection
+ Connection mockConnection = mock(Connection.class);
+ when(mockDataSource.getConnection()).thenReturn(mockConnection);
+
+ // Mock DatabaseMetaData
+ DatabaseMetaData mockMetaData = mock(DatabaseMetaData.class);
+ when(mockConnection.getMetaData()).thenReturn(mockMetaData);
+ when(mockMetaData.getURL()).thenReturn("jdbc:postgresql://localhost:5432/testdb");
+
+ // Mock PGConnection (PostgreSQL-specific connection)
+ org.postgresql.PGConnection mockPGConnection = mock(org.postgresql.PGConnection.class);
+ when(mockConnection.unwrap(org.postgresql.PGConnection.class)).thenReturn(mockPGConnection);
+
+ // Mock PGConnection's addDataType method
+ doNothing().when(mockPGConnection).addDataType(anyString(), any(Class.class));
+
+ // Mock Statement
+ Statement mockStatement = mock(Statement.class);
+ when(mockConnection.createStatement()).thenReturn(mockStatement);
+
+ // Mock SQL Execution (e.g., table creation or updates)
+ when(mockStatement.executeUpdate(anyString())).thenReturn(1);
+
+ return mockDataSource;
+ }
+
+ @Configuration
+ static class ExistingDataSourceConfig {
+
+ @Bean
+ public DataSource dataSource() throws Exception {
+ return mockPostgreDataSource();
+ }
+ }
+
+ @Configuration
+ static class MultipleDataSourceConfig {
+
+ @Bean
+ public DataSource primaryDataSource() throws Exception {
+ return mockPostgreDataSource();
+ }
+
+ @Bean
+ public DataSource secondaryDataSource() throws Exception {
+ return mockPostgreDataSource();
+ }
+ }
+
+ @Configuration
+ static class NonPostgresDataSourceConfig {
+
+ @Bean
+ public DataSource dataSource() throws Exception {
+ // Mock a non-PostgreSQL DataSource
+ DataSource mockDataSource = mock(DataSource.class);
+ Connection mockConnection = mock(Connection.class);
+ DatabaseMetaData mockMetaData = mock(DatabaseMetaData.class);
+
+ when(mockDataSource.getConnection()).thenReturn(mockConnection);
+ when(mockConnection.getMetaData()).thenReturn(mockMetaData);
+ when(mockMetaData.getURL()).thenReturn("jdbc:mysql://localhost:3306/testdb");
+
+ return mockDataSource;
+ }
+ }
+}