Skip to content

Commit

Permalink
Optimize support for multiple data sources and add additional test ca…
Browse files Browse the repository at this point in the history
…ses.
  • Loading branch information
misselvexu committed Dec 10, 2024
1 parent 8d4045f commit 875b24a
Show file tree
Hide file tree
Showing 4 changed files with 283 additions and 83 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -3,85 +3,20 @@
import org.springframework.boot.context.properties.ConfigurationProperties;

@ConfigurationProperties(prefix = PgVectorDataSourceProperties.PREFIX)
public class PgVectorDataSourceProperties {

public record PgVectorDataSourceProperties(
boolean enabled,
String host,
String user,
String password,
Integer port,
String database
) {
static final String PREFIX = "langchain4j.pgvector.datasource";

/**
* Enable postgres datasource configuration, default value <code>false</code>.
*/
private boolean enabled = false;

/**
* The pgvector database host.
*/
private String host;

/**
* The pgvector database user.
*/
private String user;

/**
* The pgvector database password.
* Provide a default constructor that sets the default value of enabled to false.
*/
private String password;

/**
* The pgvector database port.
*/
private Integer port;

/**
* The pgvector database name.
*/
private String database;

public boolean isEnabled() {
return enabled;
}

public void setEnabled(boolean enabled) {
this.enabled = enabled;
}

public String getHost() {
return host;
}

public void setHost(String host) {
this.host = host;
}

public String getUser() {
return user;
}

public void setUser(String user) {
this.user = user;
}

public String getPassword() {
return password;
}

public void setPassword(String password) {
this.password = password;
}

public Integer getPort() {
return port;
}

public void setPort(Integer port) {
this.port = port;
}

public String getDatabase() {
return database;
}

public void setDatabase(String database) {
this.database = database;
public PgVectorDataSourceProperties() {
this(false, null, null, null, null, null);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,21 @@
import dev.langchain4j.store.embedding.pgvector.PgVectorEmbeddingStore;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.ObjectProvider;
import org.springframework.boot.autoconfigure.AutoConfiguration;
import org.springframework.boot.autoconfigure.condition.ConditionalOnBean;
import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean;
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
import org.springframework.boot.context.properties.EnableConfigurationProperties;
import org.springframework.context.ApplicationContext;
import org.springframework.context.annotation.Bean;
import org.springframework.lang.Nullable;

import javax.sql.DataSource;
import java.sql.Connection;
import java.sql.DatabaseMetaData;
import java.sql.SQLException;
import java.util.Map;
import java.util.Optional;

import static dev.langchain4j.internal.ValidationUtils.*;
Expand All @@ -29,12 +32,40 @@ public class PgVectorEmbeddingStoreAutoConfiguration {

private static final Logger log = LoggerFactory.getLogger(PgVectorEmbeddingStoreAutoConfiguration.class);

private final ApplicationContext applicationContext;

public PgVectorEmbeddingStoreAutoConfiguration(ApplicationContext applicationContext) {
this.applicationContext = applicationContext;
}

@Bean
@ConditionalOnMissingBean
@ConditionalOnBean(DataSource.class)
@ConditionalOnProperty(prefix = PgVectorDataSourceProperties.PREFIX, name = "enabled", havingValue = "false")
public PgVectorEmbeddingStore pgVectorEmbeddingStoreWithExistingDataSource(DataSource dataSource, PgVectorEmbeddingStoreProperties properties,
@Nullable EmbeddingModel embeddingModel) {
public PgVectorEmbeddingStore pgVectorEmbeddingStoreWithExistingDataSource(ObjectProvider<DataSource> dataSources, PgVectorEmbeddingStoreProperties properties,
@Nullable EmbeddingModel embeddingModel) {

// The PostgreSQL data source is selected based on the configured dataSourceBeanName or automatically.
DataSource dataSource = dataSources.stream()
.filter(ds -> {
// Preferentially matches the configured dataSourceBeanName.
String beanName = properties.getDataSourceBeanName();
if (beanName != null && !beanName.isEmpty()) {
String actualBeanName = getBeanNameForDataSource(ds);
return beanName.equals(actualBeanName);
}
return false;
})
.findFirst()
// If no dataSourceBeanName is specified, the first PostgreSQL data source is selected.
.orElseGet(() -> dataSources.stream()
.filter(this::isPostgresqlDataSource)
.findFirst()
.orElseThrow(() -> new IllegalStateException("No suitable PostgreSQL DataSource found in the application context. "
+ "Please configure a valid PostgreSQL DataSource.")));

log.info("Using DataSource bean: {}", dataSource.getClass().getSimpleName());

// Check if the context's data source is a Postgres datasource
ensureTrue(isPostgresqlDataSource(dataSource), "The DataSource in Spring Context is not a Postgres datasource, you need to manually specify the Postgres datasource configuration via 'langchain4j.pgvector.datasource'.");

Expand All @@ -58,11 +89,11 @@ public PgVectorEmbeddingStore pgVectorEmbeddingStoreWithCustomDataSource(PgVecto
Integer dimension = Optional.ofNullable(properties.getDimension()).orElseGet(() -> embeddingModel == null ? null : embeddingModel.dimension());

return PgVectorEmbeddingStore.builder()
.host(dataSourceProperties.getHost())
.port(dataSourceProperties.getPort())
.user(dataSourceProperties.getUser())
.password(dataSourceProperties.getPassword())
.database(dataSourceProperties.getDatabase())
.host(dataSourceProperties.host())
.port(dataSourceProperties.port())
.user(dataSourceProperties.user())
.password(dataSourceProperties.password())
.database(dataSourceProperties.database())
.table(properties.getTable())
.createTable(properties.getCreateTable())
.dimension(dimension)
Expand All @@ -85,4 +116,18 @@ private boolean isPostgresqlDataSource(DataSource dataSource) {
return false;
}
}

/**
* Get the BeanName of the DataSource instance from the ApplicationContext.
* @param dataSource Target DataSource instance.
* @return bean name of target DataSource .
*/
private String getBeanNameForDataSource(DataSource dataSource) {
// 遍历所有 DataSource Bean,找到与当前实例匹配的 Bean 名称
return applicationContext.getBeansOfType(DataSource.class).entrySet().stream()
.filter(entry -> entry.getValue().equals(dataSource))
.map(Map.Entry::getKey)
.findFirst()
.orElse(null);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,9 @@ public class PgVectorEmbeddingStoreProperties {
*/
private Integer indexListSize;

private String dataSourceBeanName;


public String getTable() {
return table;
}
Expand Down Expand Up @@ -71,4 +74,12 @@ public Integer getIndexListSize() {
public void setIndexListSize(Integer indexListSize) {
this.indexListSize = indexListSize;
}

public String getDataSourceBeanName() {
return dataSourceBeanName;
}

public void setDataSourceBeanName(String dataSourceBeanName) {
this.dataSourceBeanName = dataSourceBeanName;
}
}
Loading

0 comments on commit 875b24a

Please sign in to comment.