Skip to content

Commit

Permalink
Changed version of DB image for tests
Browse files Browse the repository at this point in the history
  • Loading branch information
fmeheust committed Jan 10, 2025
1 parent 6f296be commit 3138115
Showing 1 changed file with 42 additions and 40 deletions.
Original file line number Diff line number Diff line change
@@ -1,20 +1,14 @@
package dev.langchain4j.store.embedding.oracle;

import static org.assertj.core.api.Assertions.assertThat;

import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.model.embedding.onnx.allminilml6v2q.AllMiniLmL6V2QuantizedEmbeddingModel;
import dev.langchain4j.model.embedding.EmbeddingModel;
import dev.langchain4j.model.embedding.onnx.allminilml6v2q.AllMiniLmL6V2QuantizedEmbeddingModel;
import dev.langchain4j.store.embedding.EmbeddingMatch;
import dev.langchain4j.store.embedding.EmbeddingSearchRequest;
import dev.langchain4j.store.embedding.EmbeddingStore;
import oracle.jdbc.OracleConnection;
import oracle.sql.CHAR;
import oracle.sql.CharacterSet;
import oracle.ucp.jdbc.PoolDataSource;
import oracle.ucp.jdbc.PoolDataSourceFactory;
import org.testcontainers.oracle.OracleContainer;

import javax.sql.DataSource;
import java.sql.Connection;
import java.sql.ResultSet;
import java.sql.SQLException;
Expand All @@ -24,8 +18,13 @@
import java.util.List;
import java.util.Random;
import java.util.logging.Logger;

import static org.assertj.core.api.Assertions.assertThat;
import javax.sql.DataSource;
import oracle.jdbc.OracleConnection;
import oracle.sql.CHAR;
import oracle.sql.CharacterSet;
import oracle.ucp.jdbc.PoolDataSource;
import oracle.ucp.jdbc.PoolDataSourceFactory;
import org.testcontainers.oracle.OracleContainer;

/**
* A collection of operations which are shared by tests in this package.
Expand All @@ -45,8 +44,9 @@ final class CommonTestOperations {
* Seed for random numbers. When a test fails, "-Ddev.langchain4j.store.embedding.oracle.SEED=..." can be used to
* re-execute it with the same random numbers.
*/
private static final long SEED = Long.getLong(
"dev.langchain4j.store.embedding.oracle.SEED", System.currentTimeMillis());
private static final long SEED =
Long.getLong("dev.langchain4j.store.embedding.oracle.SEED", System.currentTimeMillis());

static {
Logger.getLogger(CommonTestOperations.class.getName())
.info("dev.langchain4j.store.embedding.oracle.SEED=" + SEED);
Expand All @@ -62,7 +62,7 @@ private CommonTestOperations() {}
private static final PoolDataSource DATA_SOURCE = PoolDataSourceFactory.getPoolDataSource();
private static final PoolDataSource SYSDBA_DATA_SOURCE = PoolDataSourceFactory.getPoolDataSource();

public static final String ORACLE_IMAGE_NAME = "gvenzl/oracle-free:23.5-slim-faststart";
public static final String ORACLE_IMAGE_NAME = "gvenzl/oracle-free:23.6-slim-faststart";

static {
try {
Expand All @@ -71,27 +71,33 @@ private CommonTestOperations() {}

if (urlFromEnv == null) {
// The Ryuk component is relied upon to stop this container.
OracleContainer oracleContainer =
new OracleContainer(ORACLE_IMAGE_NAME)
OracleContainer oracleContainer = new OracleContainer(ORACLE_IMAGE_NAME)
.withStartupTimeout(Duration.ofSeconds(600))
.withConnectTimeoutSeconds(600)
.withDatabaseName("pdb1")
.withUsername("testuser")
.withPassword("testpwd");
oracleContainer.start();

initDataSource(DATA_SOURCE,
oracleContainer.getJdbcUrl(), oracleContainer.getUsername(), oracleContainer.getPassword());
initDataSource(SYSDBA_DATA_SOURCE,
oracleContainer.getJdbcUrl(), "sys", oracleContainer.getPassword());
initDataSource(
DATA_SOURCE,
oracleContainer.getJdbcUrl(),
oracleContainer.getUsername(),
oracleContainer.getPassword());
initDataSource(SYSDBA_DATA_SOURCE, oracleContainer.getJdbcUrl(), "sys", oracleContainer.getPassword());
} else {
initDataSource(DATA_SOURCE,
urlFromEnv, System.getenv("ORACLE_JDBC_USER"), System.getenv("ORACLE_JDBC_PASSWORD"));
initDataSource(SYSDBA_DATA_SOURCE,
urlFromEnv, System.getenv("ORACLE_JDBC_USER"), System.getenv("ORACLE_JDBC_PASSWORD"));
initDataSource(
DATA_SOURCE,
urlFromEnv,
System.getenv("ORACLE_JDBC_USER"),
System.getenv("ORACLE_JDBC_PASSWORD"));
initDataSource(
SYSDBA_DATA_SOURCE,
urlFromEnv,
System.getenv("ORACLE_JDBC_USER"),
System.getenv("ORACLE_JDBC_PASSWORD"));
}
SYSDBA_DATA_SOURCE.setConnectionProperty(OracleConnection.CONNECTION_PROPERTY_INTERNAL_LOGON,
"SYSDBA");
SYSDBA_DATA_SOURCE.setConnectionProperty(OracleConnection.CONNECTION_PROPERTY_INTERNAL_LOGON, "SYSDBA");

} catch (SQLException sqlException) {
throw new AssertionError(sqlException);
Expand All @@ -104,11 +110,9 @@ static void initDataSource(PoolDataSource dataSource, String url, String usernam
dataSource.setURL(url);
dataSource.setUser(username);
dataSource.setPassword(password);
} catch (
SQLException sqlException) {
} catch (SQLException sqlException) {
throw new AssertionError(sqlException);
}

}

static EmbeddingModel getEmbeddingModel() {
Expand All @@ -119,7 +123,9 @@ static DataSource getDataSource() {
return DATA_SOURCE;
}

static DataSource getSysDBADataSource() { return SYSDBA_DATA_SOURCE; }
static DataSource getSysDBADataSource() {
return SYSDBA_DATA_SOURCE;
}

/**
* Returns an embedding store configured to use a table with the common {@link #TABLE_NAME}. Any existing table
Expand Down Expand Up @@ -165,7 +171,7 @@ static void dropTable() throws SQLException {
*/
static void dropTable(String tableName) throws SQLException {
try (Connection connection = DATA_SOURCE.getConnection();
Statement statement = connection.createStatement()) {
Statement statement = connection.createStatement()) {
statement.addBatch("DROP INDEX IF EXISTS " + tableName + "_EMBEDDING_INDEX");
statement.addBatch("DROP TABLE IF EXISTS " + tableName);
statement.executeBatch();
Expand All @@ -178,8 +184,8 @@ static void dropTable(String tableName) throws SQLException {
*/
static CharacterSet getCharacterSet() throws SQLException {
try (Connection connection = CommonTestOperations.getDataSource().getConnection();
Statement statement = connection.createStatement();
ResultSet resultSet = statement.executeQuery("SELECT 'c' FROM sys.dual")) {
Statement statement = connection.createStatement();
ResultSet resultSet = statement.executeQuery("SELECT 'c' FROM sys.dual")) {
resultSet.next();
return resultSet.getObject(1, CHAR.class).getCharacterSet();
}
Expand All @@ -194,8 +200,7 @@ static CharacterSet getCharacterSet() throws SQLException {
static float[] randomFloats(int length) {
float[] floats = new float[length];

for (int i = 0; i < floats.length; i++)
floats[i] = RANDOM.nextFloat();
for (int i = 0; i < floats.length; i++) floats[i] = RANDOM.nextFloat();

return floats;
}
Expand All @@ -214,8 +219,7 @@ static void verifySearch(EmbeddingStore<TextSegment> embeddingStore) {
float[] vector1 = vector0.clone();

// Only higher indexes are increased in order to effect the cosine angle, and not just magnitude
for (int i = 0; i < vector1.length / 2; i++)
vector1[i] += 0.1f;
for (int i = 0; i < vector1.length / 2; i++) vector1[i] += 0.1f;

List<Embedding> embeddings = new ArrayList<>(2);
embeddings.add(Embedding.from(vector0));
Expand All @@ -231,9 +235,7 @@ static void verifySearch(EmbeddingStore<TextSegment> embeddingStore) {

// Verify the first vector is matched
EmbeddingMatch<TextSegment> match =
embeddingStore.search(request)
.matches()
.get(0);
embeddingStore.search(request).matches().get(0);
assertThat(match.embeddingId()).isEqualTo(ids.get(1));
assertThat(match.embedding().vector()).containsExactly(vector1);
}
Expand Down

0 comments on commit 3138115

Please sign in to comment.