From 076fd0590976e8d8ba3e8e25deabf2a5ba6cd706 Mon Sep 17 00:00:00 2001 From: David Jiang Date: Sat, 1 Feb 2025 17:01:22 -0500 Subject: [PATCH] Add doc loader --- langchain4j-oracle/pom.xml | 27 +++ .../loader/oracle/OracleDocumentLoader.java | 164 +++++++++++++++++ .../data/document/splitter/oracle/Chunk.java | 13 ++ .../oracle/OracleDocumentSplitter.java | 101 ++++++++++ .../langchain4j/model/oracle/Embedding.java | 11 ++ .../model/oracle/OracleEmbeddingModel.java | 172 ++++++++++++++++++ .../oracle/OracleSummaryLanguageModel.java | 67 +++++++ .../loader/OracleDocumentLoaderTest.java | 89 +++++++++ .../oracle/OracleDocumentSplitterTest.java | 115 ++++++++++++ .../oracle/OracleEmbeddingModelTest.java | 98 ++++++++++ .../model/oracle/OracleIngestTest.java | 127 +++++++++++++ .../OracleSummaryLanguageModelTest.java | 89 +++++++++ 12 files changed, 1073 insertions(+) create mode 100644 langchain4j-oracle/src/main/java/dev/langchain4j/data/document/loader/oracle/OracleDocumentLoader.java create mode 100644 langchain4j-oracle/src/main/java/dev/langchain4j/data/document/splitter/oracle/Chunk.java create mode 100644 langchain4j-oracle/src/main/java/dev/langchain4j/data/document/splitter/oracle/OracleDocumentSplitter.java create mode 100644 langchain4j-oracle/src/main/java/dev/langchain4j/model/oracle/Embedding.java create mode 100644 langchain4j-oracle/src/main/java/dev/langchain4j/model/oracle/OracleEmbeddingModel.java create mode 100644 langchain4j-oracle/src/main/java/dev/langchain4j/model/oracle/OracleSummaryLanguageModel.java create mode 100644 langchain4j-oracle/src/test/java/dev/langchain4j/data/document/loader/OracleDocumentLoaderTest.java create mode 100644 langchain4j-oracle/src/test/java/dev/langchain4j/data/document/splitter/oracle/OracleDocumentSplitterTest.java create mode 100644 langchain4j-oracle/src/test/java/dev/langchain4j/model/oracle/OracleEmbeddingModelTest.java create mode 100644 langchain4j-oracle/src/test/java/dev/langchain4j/model/oracle/OracleIngestTest.java create mode 100644 langchain4j-oracle/src/test/java/dev/langchain4j/model/oracle/OracleSummaryLanguageModelTest.java diff --git a/langchain4j-oracle/pom.xml b/langchain4j-oracle/pom.xml index 7555c794383..70ea7fa9653 100755 --- a/langchain4j-oracle/pom.xml +++ b/langchain4j-oracle/pom.xml @@ -112,6 +112,33 @@ test + + + org.jsoup + jsoup + 1.18.1 + + + com.fasterxml.jackson.core + jackson-core + 2.16.1 + + + com.fasterxml.jackson.core + jackson-databind + 2.16.1 + + + com.fasterxml.jackson.core + jackson-annotations + 2.16.1 + + + io.github.cdimascio + dotenv-java + 3.0.0 + + diff --git a/langchain4j-oracle/src/main/java/dev/langchain4j/data/document/loader/oracle/OracleDocumentLoader.java b/langchain4j-oracle/src/main/java/dev/langchain4j/data/document/loader/oracle/OracleDocumentLoader.java new file mode 100644 index 00000000000..92adfb24909 --- /dev/null +++ b/langchain4j-oracle/src/main/java/dev/langchain4j/data/document/loader/oracle/OracleDocumentLoader.java @@ -0,0 +1,164 @@ +package dev.langchain4j.data.document.loader.oracle; + +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; + +import dev.langchain4j.data.document.Document; +import dev.langchain4j.data.document.Metadata; + +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.security.InvalidParameterException; +import java.sql.Blob; +import java.sql.Connection; +import 
java.sql.PreparedStatement; +import java.sql.ResultSet; +import java.sql.SQLException; +import java.util.ArrayList; +import java.util.List; + +import org.jsoup.Jsoup; +import org.jsoup.nodes.Element; +import org.jsoup.select.Elements; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public class OracleDocumentLoader { + + private static final Logger log = LoggerFactory.getLogger(OracleDocumentLoader.class); + + private final Connection conn; + + public OracleDocumentLoader(Connection conn) { + this.conn = conn; + } + + public List loadDocuments(String pref) throws JsonProcessingException, IOException, SQLException { + List documents = new ArrayList<>(); + + ObjectMapper mapper = new ObjectMapper(); + JsonNode rootNode = mapper.readTree(pref); + JsonNode fileNode = rootNode.path("file"); + JsonNode dirNode = rootNode.path("dir"); + JsonNode ownerNode = rootNode.path("owner"); + JsonNode tableNode = rootNode.path("tablename"); + JsonNode colNode = rootNode.path("colname"); + + if (fileNode.textValue() != null) { + String filename = fileNode.textValue(); + Document doc = loadDocument(filename, pref); + if (doc != null) { + documents.add(doc); + } + } else if (dirNode.textValue() != null) { + String dir = dirNode.textValue(); + Path root = Paths.get(dir); + Files.walk(root).forEach(path -> { + if (path.toFile().isFile()) { + Document doc = null; + try { + doc = loadDocument(path.toFile().toString(), pref); + } catch (IOException | SQLException e) { + String message = e.getCause() != null ? e.getCause().getMessage() : e.getMessage(); + log.warn("Failed to summarize '{}': {}", pref, message); + } + if (doc != null) { + documents.add(doc); + } + } + }); + } else if (colNode.textValue() != null) { + String column = colNode.textValue(); + + String table = tableNode.textValue(); + String owner = ownerNode.textValue(); + if (table == null) { + throw new InvalidParameterException("Missing table in preference"); + } + if (owner == null) { + throw new InvalidParameterException("Missing owner in preference"); + } + + documents.addAll(loadDocuments(owner, table, column, pref)); + } else { + throw new InvalidParameterException("Missing file, dir, or table in preference"); + } + + return documents; + } + + private Document loadDocument(String filename, String pref) throws IOException, SQLException { + Document document = null; + + byte[] bytes = Files.readAllBytes(Paths.get(filename)); + + String query = "select dbms_vector_chain.utl_to_text(?, json(?)) text, dbms_vector_chain.utl_to_text(?, json('{\"plaintext\": \"false\"}')) metadata from dual"; + + try (PreparedStatement stmt = conn.prepareStatement(query)) { + Blob blob = conn.createBlob(); + blob.setBytes(1, bytes); + + stmt.setBlob(1, blob); + stmt.setObject(2, pref); + stmt.setBlob(3, blob); + + try (ResultSet rs = stmt.executeQuery()) { + while (rs.next()) { + String text = rs.getString("text"); + String html = rs.getString("metadata"); + + Metadata metadata = getMetadata(html); + Path path = Paths.get(filename); + metadata.put(Document.FILE_NAME, path.getFileName().toString()); + metadata.put(Document.ABSOLUTE_DIRECTORY_PATH, path.getParent().toString()); + document = Document.from(text, metadata); + } + } + } + + return document; + } + + private List loadDocuments(String owner, String table, String column, String pref) throws SQLException { + List documents = new ArrayList<>(); + + String query = String.format("select dbms_vector_chain.utl_to_text(t.%s, json(?)) text, dbms_vector_chain.utl_to_text(t.%s, json('{\"plaintext\": 
\"false\"}')) metadata from %s.%s t", + column, column, owner, table); + try (PreparedStatement stmt = conn.prepareStatement(query)) { + stmt.setObject(1, pref); + try (ResultSet rs = stmt.executeQuery()) { + while (rs.next()) { + String text = rs.getString("text"); + String html = rs.getString("metadata"); + + Metadata metadata = getMetadata(html); + Document doc = Document.from(text, metadata); + documents.add(doc); + } + } + } + + return documents; + } + + private static Metadata getMetadata(String html) { + Metadata metadata = new Metadata(); + + org.jsoup.nodes.Document doc = Jsoup.parse(html); + Elements metaTags = doc.getElementsByTag("meta"); + for (Element metaTag : metaTags) { + String name = metaTag.attr("name"); + if (name.isEmpty()) { + continue; + } + String content = metaTag.attr("content"); + metadata.put(name, content); + } + + return metadata; + } +} diff --git a/langchain4j-oracle/src/main/java/dev/langchain4j/data/document/splitter/oracle/Chunk.java b/langchain4j-oracle/src/main/java/dev/langchain4j/data/document/splitter/oracle/Chunk.java new file mode 100644 index 00000000000..c85b48d4898 --- /dev/null +++ b/langchain4j-oracle/src/main/java/dev/langchain4j/data/document/splitter/oracle/Chunk.java @@ -0,0 +1,13 @@ +package dev.langchain4j.data.document.splitter.oracle; + +public class Chunk { + + public int chunk_id; + public int chunk_offset; + public int chunk_length; + public String chunk_data; + + public Chunk() { + } + +} diff --git a/langchain4j-oracle/src/main/java/dev/langchain4j/data/document/splitter/oracle/OracleDocumentSplitter.java b/langchain4j-oracle/src/main/java/dev/langchain4j/data/document/splitter/oracle/OracleDocumentSplitter.java new file mode 100644 index 00000000000..eef1efa524c --- /dev/null +++ b/langchain4j-oracle/src/main/java/dev/langchain4j/data/document/splitter/oracle/OracleDocumentSplitter.java @@ -0,0 +1,101 @@ +package dev.langchain4j.data.document.splitter.oracle; + +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.ObjectMapper; + +import dev.langchain4j.data.document.Document; +import dev.langchain4j.data.document.DocumentSplitter; +import dev.langchain4j.data.document.Metadata; +import dev.langchain4j.data.segment.TextSegment; + +import java.sql.Connection; +import java.sql.PreparedStatement; +import java.sql.ResultSet; +import java.sql.SQLException; +import java.util.ArrayList; +import java.util.List; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public class OracleDocumentSplitter implements DocumentSplitter { + + private static final Logger log = LoggerFactory.getLogger(OracleDocumentSplitter.class); + + private static final String INDEX = "index"; + + private final Connection conn; + private final String pref; + + public OracleDocumentSplitter(Connection conn, String pref) { + this.conn = conn; + this.pref = pref; + } + + @Override + public List split(Document document) { + List segments = new ArrayList<>(); + try { + String[] parts = split(document.text()); + int index = 0; + for (String part : parts) { + segments.add(createSegment(part, document, index)); + index++; + } + } catch (SQLException | JsonProcessingException e) { + String message = e.getCause() != null ? e.getCause().getMessage() : e.getMessage(); + log.warn("Failed to summarize '{}': {}", pref, message); + } + return segments; + } + + @Override + public List splitAll(List list) { + return DocumentSplitter.super.splitAll(list); + } + + /** + * Splits the provided text into parts. Implementation API. 
+     *
+     * @param content The text to be split.
+     * @return An array of parts.
+     */
+    public String[] split(String content) throws SQLException, JsonProcessingException {
+
+        List<String> strArr = new ArrayList<>();
+
+        String query = "select t.column_value as data from dbms_vector_chain.utl_to_chunks(?, json(?)) t";
+        try (PreparedStatement stmt = conn.prepareStatement(query)) {
+            stmt.setObject(1, content);
+            stmt.setObject(2, pref);
+            try (ResultSet rs = stmt.executeQuery()) {
+                while (rs.next()) {
+                    String text = rs.getString("data");
+
+                    ObjectMapper mapper = new ObjectMapper();
+                    Chunk chunk = mapper.readValue(text, Chunk.class);
+                    strArr.add(chunk.chunk_data);
+                }
+            }
+        }
+
+        return strArr.toArray(new String[strArr.size()]);
+    }
+
+    /**
+     * Creates a new {@link TextSegment} from the provided text and document.
+     *
+     * <p>
+ * The segment inherits all metadata from the document. The segment also + * includes an "index" metadata key representing the segment position within + * the document. + * + * @param text The text of the segment. + * @param document The document to which the segment belongs. + * @param index The index of the segment within the document. + */ + static TextSegment createSegment(String text, Document document, int index) { + Metadata metadata = document.metadata().copy().put(INDEX, String.valueOf(index)); + return TextSegment.from(text, metadata); + } +} diff --git a/langchain4j-oracle/src/main/java/dev/langchain4j/model/oracle/Embedding.java b/langchain4j-oracle/src/main/java/dev/langchain4j/model/oracle/Embedding.java new file mode 100644 index 00000000000..96eba8ebb01 --- /dev/null +++ b/langchain4j-oracle/src/main/java/dev/langchain4j/model/oracle/Embedding.java @@ -0,0 +1,11 @@ +package dev.langchain4j.model.oracle; + +public class Embedding { + + public int embed_id; + public String embed_data; + public String embed_vector; + + public Embedding() { + } +} diff --git a/langchain4j-oracle/src/main/java/dev/langchain4j/model/oracle/OracleEmbeddingModel.java b/langchain4j-oracle/src/main/java/dev/langchain4j/model/oracle/OracleEmbeddingModel.java new file mode 100644 index 00000000000..74fabed68fb --- /dev/null +++ b/langchain4j-oracle/src/main/java/dev/langchain4j/model/oracle/OracleEmbeddingModel.java @@ -0,0 +1,172 @@ +package dev.langchain4j.model.oracle; + +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.ObjectMapper; + +import dev.langchain4j.data.document.splitter.oracle.Chunk; +import dev.langchain4j.data.embedding.Embedding; +import dev.langchain4j.data.segment.TextSegment; +import dev.langchain4j.model.embedding.DimensionAwareEmbeddingModel; +import dev.langchain4j.model.output.Response; + +import java.sql.Array; +import java.sql.Clob; +import java.sql.Connection; +import java.sql.PreparedStatement; +import java.sql.ResultSet; +import java.sql.SQLException; +import java.util.ArrayList; +import java.util.List; +import static java.util.stream.Collectors.toList; + +import oracle.jdbc.OracleConnection; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public class OracleEmbeddingModel extends DimensionAwareEmbeddingModel { + + private static final Logger log = LoggerFactory.getLogger(OracleEmbeddingModel.class); + + private final Connection conn; + private final String pref; + private final String proxy; + private boolean batching = true; + + public OracleEmbeddingModel(Connection conn, String pref) { + this.conn = conn; + this.pref = pref; + this.proxy = ""; + } + + public OracleEmbeddingModel(Connection conn, String pref, String proxy) { + this.conn = conn; + this.pref = pref; + this.proxy = proxy; + } + + void setBatching(boolean batching) { + this.batching = batching; + } + + boolean getBatching() { + return this.batching; + } + + static boolean loadOnnxModel(Connection conn, String dir, String onnxFile, String modelName) throws SQLException { + boolean result = false; + + String query = "begin\n" + + " dbms_data_mining.drop_model(?, force => true);\n" + + " dbms_vector.load_onnx_model(?, ?, ?,\n" + + " json('{\"function\" : \"embedding\", \"embeddingOutput\" : \"embedding\" , \"input\": {\"input\": [\"DATA\"]}}'));\n" + + "end;"; + PreparedStatement stmt = conn.prepareStatement(query); + stmt.setObject(1, modelName); + stmt.setObject(2, dir); + stmt.setObject(3, onnxFile); + stmt.setObject(4, modelName); + 
stmt.execute(); + result = true; + + return result; + } + + @Override + public Response> embedAll(List textSegments) { + List texts = textSegments.stream() + .map(TextSegment::text) + .collect(toList()); + + return embedTexts(texts); + } + + private Response> embedTexts(List inputs) { + List embeddings = new ArrayList<>(); + + try { + if (proxy != null && !proxy.isEmpty()) { + String query = "begin utl_http.set_proxy(?); end;"; + try (PreparedStatement stmt = conn.prepareStatement(query)) { + stmt.setObject(1, proxy); + stmt.execute(); + } + } + + if (!batching) { + for (String input : inputs) { + String query = "select t.column_value as data from dbms_vector_chain.utl_to_embeddings(?, json(?)) t"; + try (PreparedStatement stmt = conn.prepareStatement(query)) { + stmt.setObject(1, input); + stmt.setObject(2, pref); + try (ResultSet rs = stmt.executeQuery()) { + while (rs.next()) { + String text = rs.getString("data"); + + ObjectMapper mapper = new ObjectMapper(); + dev.langchain4j.model.oracle.Embedding dbmsEmbedding = mapper.readValue(text, dev.langchain4j.model.oracle.Embedding.class); + Embedding embedding = new Embedding(toFloatArray(dbmsEmbedding.embed_vector)); + embeddings.add(embedding); + } + } + } + } + } else { + // createOracleArray needs to passed a Clob array since vector_array_t is a table of clob + // if a String array is passed, will get ORA-17059: Failed to convert to internal representation + List elements = toClobList(conn, inputs); + Array arr = ((OracleConnection) conn).createOracleArray("SYS.VECTOR_ARRAY_T", elements.toArray()); + + String query = "select t.column_value as data from dbms_vector_chain.utl_to_embeddings(?, json(?)) t"; + try (PreparedStatement stmt = conn.prepareStatement(query)) { + stmt.setObject(1, arr); + stmt.setObject(2, pref); + try (ResultSet rs = stmt.executeQuery()) { + while (rs.next()) { + String text = rs.getString("data"); + + ObjectMapper mapper = new ObjectMapper(); + dev.langchain4j.model.oracle.Embedding dbmsEmbedding = mapper.readValue(text, dev.langchain4j.model.oracle.Embedding.class); + Embedding embedding = new Embedding(toFloatArray(dbmsEmbedding.embed_vector)); + embeddings.add(embedding); + + } + } + } + } + } catch (SQLException | JsonProcessingException e) { + String message = e.getCause() != null ? 
e.getCause().getMessage() : e.getMessage(); + log.warn("Failed to summarize '{}': {}", pref, message); + } + + return Response.from(embeddings); + } + + private List toClobList(Connection conn, List inputs) throws JsonProcessingException, SQLException { + ObjectMapper objectMapper = new ObjectMapper(); + + List chunks = new ArrayList<>(); + for (int i = 0; i < inputs.size(); i++) { + // Create JSON string + Chunk chunk = new Chunk(); + chunk.chunk_id = i; + chunk.chunk_data = inputs.get(i); + String jsonString = objectMapper.writeValueAsString(chunk); + + Clob clob = conn.createClob(); + clob.setString(1, jsonString); + chunks.add(clob); + } + return chunks; + } + + private float[] toFloatArray(String embedding) { + String str = embedding.replace("[", "").replace("]", ""); + String[] strArr = str.split(","); + float[] floatArr = new float[strArr.length]; + for (int i = 0; i < strArr.length; i++) { + floatArr[i] = Float.parseFloat(strArr[i]); + } + return floatArr; + } +} diff --git a/langchain4j-oracle/src/main/java/dev/langchain4j/model/oracle/OracleSummaryLanguageModel.java b/langchain4j-oracle/src/main/java/dev/langchain4j/model/oracle/OracleSummaryLanguageModel.java new file mode 100644 index 00000000000..34eaf8b95cb --- /dev/null +++ b/langchain4j-oracle/src/main/java/dev/langchain4j/model/oracle/OracleSummaryLanguageModel.java @@ -0,0 +1,67 @@ +package dev.langchain4j.model.oracle; + +import dev.langchain4j.model.language.LanguageModel; +import dev.langchain4j.model.output.Response; + +import java.sql.Connection; +import java.sql.PreparedStatement; +import java.sql.ResultSet; +import java.sql.SQLException; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public class OracleSummaryLanguageModel implements LanguageModel { + + private static final Logger log = LoggerFactory.getLogger(OracleSummaryLanguageModel.class); + + private final Connection conn; + private final String pref; + private final String proxy; + + public OracleSummaryLanguageModel(Connection conn, String pref) { + this.conn = conn; + this.pref = pref; + this.proxy = ""; + } + + public OracleSummaryLanguageModel(Connection conn, String pref, String proxy) { + this.conn = conn; + this.pref = pref; + this.proxy = proxy; + } + + @Override + public Response generate(String prompt) { + + String text = ""; + + try { + if (proxy != null && !proxy.isEmpty()) { + String query = "begin utl_http.set_proxy(?); end;"; + try (PreparedStatement stmt = conn.prepareStatement(query)) { + stmt.setObject(1, proxy); + stmt.execute(); + } + } + + String query = "select dbms_vector_chain.utl_to_summary(?, json(?)) data from dual"; + + try (PreparedStatement stmt = conn.prepareStatement(query)) { + stmt.setObject(1, prompt); + stmt.setObject(2, pref); + + try (ResultSet rs = stmt.executeQuery()) { + while (rs.next()) { + text = rs.getString("data"); + } + } + } + } catch (SQLException e) { + String message = e.getCause() != null ? 
e.getCause().getMessage() : e.getMessage(); + log.warn("Failed to summarize '{}': {}", pref, message); + } + + return Response.from(text); + } +} diff --git a/langchain4j-oracle/src/test/java/dev/langchain4j/data/document/loader/OracleDocumentLoaderTest.java b/langchain4j-oracle/src/test/java/dev/langchain4j/data/document/loader/OracleDocumentLoaderTest.java new file mode 100644 index 00000000000..7acc9610b8f --- /dev/null +++ b/langchain4j-oracle/src/test/java/dev/langchain4j/data/document/loader/OracleDocumentLoaderTest.java @@ -0,0 +1,89 @@ +package dev.langchain4j.data.document.loader; + +import dev.langchain4j.data.document.loader.oracle.OracleDocumentLoader; +import dev.langchain4j.data.document.Document; + +import java.io.IOException; +import java.sql.Connection; +import java.sql.DriverManager; +import java.sql.SQLException; +import java.util.List; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Test; +import static org.assertj.core.api.Assertions.assertThat; +import org.slf4j.LoggerFactory; + +import io.github.cdimascio.dotenv.Dotenv; + +public class OracleDocumentLoaderTest { + + private static final org.slf4j.Logger log = LoggerFactory.getLogger(OracleDocumentLoaderTest.class); + + Dotenv dotenv; + OracleDocumentLoader loader; + + @BeforeEach + void setUp() { + dotenv = Dotenv.configure().load(); + + try { + Connection conn = DriverManager.getConnection( + dotenv.get("ORACLE_JDBC_URL"), dotenv.get("ORACLE_JDBC_USER"), dotenv.get("ORACLE_JDBC_PASSWORD")); + loader = new OracleDocumentLoader(conn); + } catch (SQLException ex) { + String message = ex.getCause() != null ? ex.getCause().getMessage() : ex.getMessage(); + log.error(message); + } + } + + @Test + @DisplayName("load from file") + void testFile() { + try { + String pref = "{\"file\": \"" + dotenv.get("DEMO_DS_PDF_FILE") + "\"}"; + List docs = loader.loadDocuments(pref); + assertThat(docs.size()).isEqualTo(1); + for (Document doc : docs) { + assertThat(doc.text().length()).isGreaterThan(0); + } + } catch (IOException | SQLException ex) { + String message = ex.getCause() != null ? ex.getCause().getMessage() : ex.getMessage(); + log.error(message); + } + } + + @Test + @DisplayName("load from dir") + void testDir() { + try { + String pref = "{\"dir\": \"" + dotenv.get("DEMO_DS_DIR") + "\"}"; + List docs = loader.loadDocuments(pref); + assertThat(docs.size()).isGreaterThan(1); + for (Document doc : docs) { + assertThat(doc.text().length()).isGreaterThan(0); + } + } catch (IOException | SQLException ex) { + String message = ex.getCause() != null ? ex.getCause().getMessage() : ex.getMessage(); + log.error(message); + } + } + + @Test + @DisplayName("load from table") + void testTable() { + try { + String pref = "{\"owner\": \"" + dotenv.get("DEMO_DS_OWNER") + "\", \"tablename\": \"" + dotenv.get("DEMO_DS_TABLE") + "\", \"colname\": \"" + dotenv.get("DEMO_DS_COLUMN") + "\"}"; + List docs = loader.loadDocuments(pref); + assertThat(docs.size()).isGreaterThan(1); + for (Document doc : docs) { + assertThat(doc.text().length()).isGreaterThan(0); + } + } catch (IOException | SQLException ex) { + String message = ex.getCause() != null ? 
ex.getCause().getMessage() : ex.getMessage(); + log.error(message); + } + } + +} diff --git a/langchain4j-oracle/src/test/java/dev/langchain4j/data/document/splitter/oracle/OracleDocumentSplitterTest.java b/langchain4j-oracle/src/test/java/dev/langchain4j/data/document/splitter/oracle/OracleDocumentSplitterTest.java new file mode 100644 index 00000000000..55ad4d9c277 --- /dev/null +++ b/langchain4j-oracle/src/test/java/dev/langchain4j/data/document/splitter/oracle/OracleDocumentSplitterTest.java @@ -0,0 +1,115 @@ +package dev.langchain4j.data.document.splitter.oracle; + +import dev.langchain4j.data.document.Document; +import dev.langchain4j.data.document.Metadata; +import dev.langchain4j.data.segment.TextSegment; + +import java.io.IOException; +import java.nio.charset.Charset; +import java.nio.file.Files; +import java.nio.file.Paths; +import java.sql.Connection; +import java.sql.DriverManager; +import java.sql.SQLException; +import java.util.List; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Test; +import org.slf4j.LoggerFactory; +import static org.assertj.core.api.Assertions.assertThat; + +import io.github.cdimascio.dotenv.Dotenv; + +public class OracleDocumentSplitterTest { + + private static final org.slf4j.Logger log = LoggerFactory.getLogger(OracleDocumentSplitterTest.class); + + Dotenv dotenv; + Connection conn; + + @BeforeEach + void setUp() { + dotenv = Dotenv.configure().load(); + + try { + conn = DriverManager.getConnection( + dotenv.get("ORACLE_JDBC_URL"), dotenv.get("ORACLE_JDBC_USER"), dotenv.get("ORACLE_JDBC_PASSWORD")); + } catch (SQLException ex) { + String message = ex.getCause() != null ? ex.getCause().getMessage() : ex.getMessage(); + log.error(message); + } + } + + @Test + @DisplayName("split string input by chars") + void testByChars() { + String pref = "{\"by\": \"chars\", \"max\": 50}"; + String filename = dotenv.get("DEMO_DS_TEXT_FILE"); + + try { + OracleDocumentSplitter splitter = new OracleDocumentSplitter(conn, pref); + + String content = readFile(filename, Charset.forName("UTF-8")); + String[] chunks = splitter.split(content); + assertThat(chunks.length).isGreaterThan(1); + } catch (IOException | SQLException ex) { + String message = ex.getCause() != null ? ex.getCause().getMessage() : ex.getMessage(); + log.error(message); + } + } + + @Test + @DisplayName("split string input by words") + void testByWords() { + String pref = "{\"by\": \"words\", \"max\": 50}"; + String filename = dotenv.get("DEMO_DS_TEXT_FILE");; + + try { + OracleDocumentSplitter splitter = new OracleDocumentSplitter(conn, pref); + + String content = readFile(filename, Charset.forName("UTF-8")); + String[] chunks = splitter.split(content); + assertThat(chunks.length).isGreaterThan(1); + } catch (IOException | SQLException ex) { + String message = ex.getCause() != null ? 
ex.getCause().getMessage() : ex.getMessage(); + log.error(message); + } + } + + @Test + @DisplayName("split Doc input by chars") + void testDocByChars() { + String pref = "{\"by\": \"chars\", \"max\": 50}"; + String filename = dotenv.get("DEMO_DS_TEXT_FILE"); + + try { + OracleDocumentSplitter splitter = new OracleDocumentSplitter(conn, pref); + + String content = readFile(filename, Charset.forName("UTF-8")); + + // Create a document with some metadata + Metadata metadata = new Metadata(); + metadata.put("a", 1); + metadata.put("b", 2); + Document document = Document.from(content, metadata); + + List chunks = splitter.split(document); + assertThat(chunks.size()).isGreaterThan(1); + + // Check that the metadata was passed + TextSegment chunk = chunks.get(0); + int a = chunk.metadata().getInteger("a"); + assertThat(a).isEqualTo(1); + } catch (IOException ex) { + String message = ex.getCause() != null ? ex.getCause().getMessage() : ex.getMessage(); + log.error(message); + } + } + + static String readFile(String path, Charset encoding) + throws IOException { + byte[] bytes = Files.readAllBytes(Paths.get(path)); + return new String(bytes, encoding); + } +} diff --git a/langchain4j-oracle/src/test/java/dev/langchain4j/model/oracle/OracleEmbeddingModelTest.java b/langchain4j-oracle/src/test/java/dev/langchain4j/model/oracle/OracleEmbeddingModelTest.java new file mode 100644 index 00000000000..fd584b06a2e --- /dev/null +++ b/langchain4j-oracle/src/test/java/dev/langchain4j/model/oracle/OracleEmbeddingModelTest.java @@ -0,0 +1,98 @@ +package dev.langchain4j.model.oracle; + +import dev.langchain4j.data.embedding.Embedding; +import dev.langchain4j.data.segment.TextSegment; +import dev.langchain4j.model.output.Response; + +import java.sql.Connection; +import java.sql.DriverManager; +import java.sql.SQLException; +import java.util.ArrayList; +import java.util.List; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Test; +import static org.assertj.core.api.Assertions.assertThat; +import org.slf4j.LoggerFactory; + +import io.github.cdimascio.dotenv.Dotenv; + +public class OracleEmbeddingModelTest { + + private static final org.slf4j.Logger log = LoggerFactory.getLogger(OracleEmbeddingModelTest.class); + + Dotenv dotenv; + Connection conn; + + @BeforeEach + void setUp() { + dotenv = Dotenv.configure().load(); + + try { + conn = DriverManager.getConnection( + dotenv.get("ORACLE_JDBC_URL"), dotenv.get("ORACLE_JDBC_USER"), dotenv.get("ORACLE_JDBC_PASSWORD")); + } catch (SQLException ex) { + String message = ex.getCause() != null ? 
ex.getCause().getMessage() : ex.getMessage(); + log.error(message); + } + } + + @Test + @DisplayName("embed with provider=database") + void testEmbedONNX() { + try { + String pref = "{\"provider\": \"database\", \"model\": \"" + dotenv.get("DEMO_ONNX_MODEL") + "\"}"; + + OracleEmbeddingModel embedder = new OracleEmbeddingModel(conn, pref); + + boolean result = OracleEmbeddingModel.loadOnnxModel(conn, dotenv.get("DEMO_ONNX_DIR"), dotenv.get("DEMO_ONNX_FILE"), dotenv.get("DEMO_ONNX_MODEL")); + assertThat(result).isEqualTo(true); + + Response resp = embedder.embed("hello world"); + assertThat(resp.content().dimension()).isGreaterThan(1); + + TextSegment segment = TextSegment.from("hello world"); + Response resp2 = embedder.embed(segment); + assertThat(resp2.content().dimension()).isGreaterThan(1); + + List textSegments = new ArrayList<>(); + textSegments.add(TextSegment.from("hello world")); + textSegments.add(TextSegment.from("goodbye world")); + textSegments.add(TextSegment.from("1,2,3")); + Response> resp3 = embedder.embedAll(textSegments); + assertThat(resp3.content().size()).isEqualTo(3); + } catch (SQLException ex) { + String message = ex.getCause() != null ? ex.getCause().getMessage() : ex.getMessage(); + log.error(message); + } + } + + @Test + @DisplayName("embed with provider=ocigenai") + void testEmbedOcigenai() { + String pref = "{\n" + + " \"provider\": \"ocigenai\",\n" + + " \"credential_name\": \"OCI_CRED\",\n" + + " \"url\": \"https://inference.generativeai.us-chicago-1.oci.oraclecloud.com/20231130/actions/embedText\",\n" + + " \"model\": \"cohere.embed-english-light-v3.0\"\n" + + "}"; + String proxy = dotenv.get("DEMO_PROXY"); + + OracleEmbeddingModel embedder = new OracleEmbeddingModel(conn, pref, proxy); + + Response resp = embedder.embed("hello world"); + assertThat(resp.content().dimension()).isGreaterThan(1); + + TextSegment segment = TextSegment.from("hello world"); + Response resp2 = embedder.embed(segment); + assertThat(resp2.content().dimension()).isGreaterThan(1); + + List textSegments = new ArrayList<>(); + textSegments.add(TextSegment.from("hello world")); + textSegments.add(TextSegment.from("goodbye world")); + textSegments.add(TextSegment.from("1,2,3")); + Response> resp3 = embedder.embedAll(textSegments); + assertThat(resp3.content().size()).isEqualTo(3); + } +} diff --git a/langchain4j-oracle/src/test/java/dev/langchain4j/model/oracle/OracleIngestTest.java b/langchain4j-oracle/src/test/java/dev/langchain4j/model/oracle/OracleIngestTest.java new file mode 100644 index 00000000000..d555cf9ae48 --- /dev/null +++ b/langchain4j-oracle/src/test/java/dev/langchain4j/model/oracle/OracleIngestTest.java @@ -0,0 +1,127 @@ +package dev.langchain4j.model.oracle; + +import dev.langchain4j.data.document.Document; +import dev.langchain4j.data.segment.TextSegment; +import dev.langchain4j.model.output.Response; +import dev.langchain4j.store.embedding.EmbeddingStoreIngestor; +import dev.langchain4j.data.document.loader.oracle.OracleDocumentLoader; +import dev.langchain4j.data.document.splitter.oracle.OracleDocumentSplitter; +import static dev.langchain4j.store.embedding.oracle.CreateOption.CREATE_OR_REPLACE; +import dev.langchain4j.store.embedding.oracle.EmbeddingTable; +import dev.langchain4j.store.embedding.oracle.OracleEmbeddingStore; + +import java.io.IOException; +import java.sql.Connection; +import java.sql.DriverManager; +import java.sql.PreparedStatement; +import java.sql.ResultSet; +import java.sql.SQLException; +import java.util.List; + +import 
org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Test; +import static org.assertj.core.api.Assertions.assertThat; +import org.slf4j.LoggerFactory; + +import io.github.cdimascio.dotenv.Dotenv; + +public class OracleIngestTest { + + private static final org.slf4j.Logger log = LoggerFactory.getLogger(OracleEmbeddingModelTest.class); + + Dotenv dotenv; + Connection conn; + + @BeforeEach + void setUp() { + dotenv = Dotenv.configure().load(); + + try { + conn = DriverManager.getConnection( + dotenv.get("ORACLE_JDBC_URL"), dotenv.get("ORACLE_JDBC_USER"), dotenv.get("ORACLE_JDBC_PASSWORD")); + } catch (SQLException ex) { + String message = ex.getCause() != null ? ex.getCause().getMessage() : ex.getMessage(); + log.error(message); + } + } + + @Test + @DisplayName("ingest") + void testIngest() { + try { + String embedderPref = "{\"provider\": \"database\", \"model\": \"" + dotenv.get("DEMO_ONNX_MODEL") + "\"}"; + String splitterPref = "{\"by\": \"chars\", \"max\": 50}"; + + OracleDocumentLoader loader = new OracleDocumentLoader(conn); + OracleEmbeddingModel embedder = new OracleEmbeddingModel(conn, embedderPref); + OracleDocumentSplitter splitter = new OracleDocumentSplitter(conn, splitterPref); + + oracle.jdbc.datasource.OracleDataSource ods + = new oracle.jdbc.datasource.impl.OracleDataSource(); + ods.setURL(dotenv.get("ORACLE_JDBC_URL")); + ods.setUser(dotenv.get("ORACLE_JDBC_USER")); + ods.setPassword(dotenv.get("ORACLE_JDBC_PASSWORD")); + + // output table + String tableName = "TEST"; + String idColumn = "ID"; + String embeddingColumn = "EMBEDDING"; + String textColumn = "TEXT"; + String metadataColumn = "METADATA"; + + // The call to build() should create a table with the configured names + OracleEmbeddingStore embeddingStore + = OracleEmbeddingStore.builder() + .dataSource(ods) + .embeddingTable(EmbeddingTable.builder() + .createOption(CREATE_OR_REPLACE) + .name(tableName) + .idColumn(idColumn) + .embeddingColumn(embeddingColumn) + .textColumn(textColumn) + .metadataColumn(metadataColumn) + .build()) + .build(); + + boolean result = OracleEmbeddingModel.loadOnnxModel(conn, dotenv.get("DEMO_ONNX_DIR"), dotenv.get("DEMO_ONNX_FILE"), dotenv.get("DEMO_ONNX_MODEL")); + + String loaderPref = "{\"file\": \"" + dotenv.get("DEMO_DS_PDF_FILE") + "\"}"; + List docs = loader.loadDocuments(loaderPref); + + EmbeddingStoreIngestor ingestor = EmbeddingStoreIngestor.builder() + .documentSplitter(splitter) + .embeddingModel(embedder) + .embeddingStore(embeddingStore) + .build(); + ingestor.ingest(docs); + + /* + // for debugging. ingest output should match + System.out.println("# docs=" + docs.size()); + List splits = splitter.splitAll(docs); + System.out.println("# split=" + splits.size()); + Response> embeddings = embedder.embedAll(splits); + System.out.println("# embedded=" + embeddings.content().size()); + */ + + int count = getCount(tableName); + assertThat(count).isGreaterThan(0); + } catch (SQLException | IOException ex) { + String message = ex.getCause() != null ? 
ex.getCause().getMessage() : ex.getMessage(); + log.error(message); + } + } + + int getCount(String tableName) throws SQLException { + int count = 0; + String query = "select count(*) from " + tableName; + PreparedStatement stmt = conn.prepareStatement(query); + try (ResultSet rs = stmt.executeQuery()) { + while (rs.next()) { + count = rs.getInt(1); + } + } + return count; + } +} diff --git a/langchain4j-oracle/src/test/java/dev/langchain4j/model/oracle/OracleSummaryLanguageModelTest.java b/langchain4j-oracle/src/test/java/dev/langchain4j/model/oracle/OracleSummaryLanguageModelTest.java new file mode 100644 index 00000000000..8ec40440be8 --- /dev/null +++ b/langchain4j-oracle/src/test/java/dev/langchain4j/model/oracle/OracleSummaryLanguageModelTest.java @@ -0,0 +1,89 @@ +package dev.langchain4j.model.oracle; + +import dev.langchain4j.model.output.Response; + +import java.io.IOException; +import java.nio.charset.Charset; +import java.nio.file.Files; +import java.nio.file.Paths; +import java.sql.Connection; +import java.sql.DriverManager; +import java.sql.SQLException; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Test; +import org.slf4j.LoggerFactory; +import static org.assertj.core.api.Assertions.assertThat; + +import io.github.cdimascio.dotenv.Dotenv; +import io.github.cdimascio.dotenv.DotenvException; + +public class OracleSummaryLanguageModelTest { + + private static final org.slf4j.Logger log = LoggerFactory.getLogger(OracleSummaryLanguageModelTest.class); + + Dotenv dotenv; + Connection conn; + + @BeforeEach + void setUp() { + dotenv = Dotenv.configure().load(); + + try { + conn = DriverManager.getConnection( + dotenv.get("ORACLE_JDBC_URL"), dotenv.get("ORACLE_JDBC_USER"), dotenv.get("ORACLE_JDBC_PASSWORD")); + } catch (SQLException ex) { + String message = ex.getCause() != null ? ex.getCause().getMessage() : ex.getMessage(); + log.error(message); + } + } + + @Test + @DisplayName("summary with provider=database") + void testSummaryDatabase() { + try { + String pref = "{\"provider\": \"database\", \"gLevel\": \"S\"}"; + + OracleSummaryLanguageModel model = new OracleSummaryLanguageModel(conn, pref); + + String filename = dotenv.get("DEMO_DS_TEXT_FILE"); + String content = readFile(filename, Charset.forName("UTF-8")); + Response resp = model.generate(content); + assertThat(resp.content().length()).isGreaterThan(0); + } catch (IOException ex) { + String message = ex.getCause() != null ? ex.getCause().getMessage() : ex.getMessage(); + log.error(message); + } + } + + @Test + @DisplayName("summary with provider=OCIGenAI") + void testSummaryOcigenai() { + try { + String pref = "{\n" + + " \"provider\": \"ocigenai\",\n" + + " \"credential_name\": \"OCI_CRED\",\n" + + " \"url\": \"https://inference.generativeai.us-chicago-1.oci.oraclecloud.com/20231130/actions/chat\",\n" + + " \"model\": \"cohere.command-r-16k\",\n" + + "}"; + String proxy = dotenv.get("DEMO_PROXY"); + + OracleSummaryLanguageModel model = new OracleSummaryLanguageModel(conn, pref, proxy); + + String filename = dotenv.get("DEMO_DS_TEXT_FILE"); + String content = readFile(filename, Charset.forName("UTF-8")); + Response resp = model.generate(content); + assertThat(resp.content().length()).isGreaterThan(0); + } catch (IOException ex) { + String message = ex.getCause() != null ? 
ex.getCause().getMessage() : ex.getMessage(); + log.error(message); + } + } + + static String readFile(String path, Charset encoding) + throws IOException { + byte[] bytes = Files.readAllBytes(Paths.get(path)); + return new String(bytes, encoding); + } +}
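
The loader is driven entirely by a JSON preference: a "file" key loads one document, "dir" walks a directory tree, and "owner"/"tablename"/"colname" read documents out of a table column; in every case dbms_vector_chain.utl_to_text does the conversion and the <meta> tags of its HTML output become Document metadata. A minimal usage sketch, assuming an open JDBC connection and placeholder paths/credentials (none of these values come from the patch):

import dev.langchain4j.data.document.Document;
import dev.langchain4j.data.document.loader.oracle.OracleDocumentLoader;

import java.sql.Connection;
import java.sql.DriverManager;
import java.util.List;

public class LoaderSketch {

    public static void main(String[] args) throws Exception {
        // Placeholder connection details; use your own environment values.
        Connection conn = DriverManager.getConnection(
                "jdbc:oracle:thin:@//host:1521/service", "user", "password");

        OracleDocumentLoader loader = new OracleDocumentLoader(conn);

        // Load a single file; {"dir": ...} and
        // {"owner": ..., "tablename": ..., "colname": ...} work the same way.
        List<Document> docs = loader.loadDocuments("{\"file\": \"/path/to/sample.pdf\"}");

        for (Document doc : docs) {
            // Text extracted by utl_to_text plus file name / directory metadata.
            System.out.println(doc.metadata());
        }
    }
}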
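
Splitting goes through dbms_vector_chain.utl_to_chunks, configured with the same style of JSON preference ("by" and "max"). A short sketch, reusing a connection as above; the chunking parameters are illustrative:

import dev.langchain4j.data.document.Document;
import dev.langchain4j.data.document.splitter.oracle.OracleDocumentSplitter;
import dev.langchain4j.data.segment.TextSegment;

import java.sql.Connection;
import java.util.List;

public class SplitterSketch {

    static List<TextSegment> splitByWords(Connection conn, Document document) {
        // At most 50 words per chunk (illustrative values).
        OracleDocumentSplitter splitter =
                new OracleDocumentSplitter(conn, "{\"by\": \"words\", \"max\": 50}");

        // Each TextSegment keeps the document metadata plus an "index" key
        // holding the segment's position in the document.
        return splitter.split(document);
    }
}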
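
With "provider": "database" the embeddings are produced by an ONNX model inside the database (the tests stage it via the package-private loadOnnxModel helper, which wraps dbms_vector.load_onnx_model). A sketch assuming a model named "demo_model" is already loaded in the schema; by default all inputs are sent to utl_to_embeddings in a single SYS.VECTOR_ARRAY_T batch:

import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.model.oracle.OracleEmbeddingModel;
import dev.langchain4j.model.output.Response;

import java.sql.Connection;
import java.util.Arrays;
import java.util.List;

public class EmbeddingSketch {

    static void embed(Connection conn) {
        // "demo_model" is a placeholder for an ONNX model already loaded in the database.
        String pref = "{\"provider\": \"database\", \"model\": \"demo_model\"}";
        OracleEmbeddingModel embedder = new OracleEmbeddingModel(conn, pref);

        Response<Embedding> single = embedder.embed("hello world");
        System.out.println("dimension = " + single.content().dimension());

        Response<List<Embedding>> batch = embedder.embedAll(Arrays.asList(
                TextSegment.from("hello world"),
                TextSegment.from("goodbye world")));
        System.out.println("embeddings = " + batch.content().size());
    }
}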
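
Summaries come from dbms_vector_chain.utl_to_summary; for remote providers such as "ocigenai" the optional third constructor argument sets a UTL_HTTP proxy before the call. A sketch of the in-database provider, using the same gLevel value as the tests:

import dev.langchain4j.model.oracle.OracleSummaryLanguageModel;
import dev.langchain4j.model.output.Response;

import java.sql.Connection;

public class SummarySketch {

    static String summarize(Connection conn, String text) {
        // Preference matching the test configuration; adjust gLevel as needed.
        String pref = "{\"provider\": \"database\", \"gLevel\": \"S\"}";
        OracleSummaryLanguageModel model = new OracleSummaryLanguageModel(conn, pref);

        Response<String> response = model.generate(text);
        return response.content();
    }
}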
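
OracleIngestTest chains the pieces into langchain4j's standard pipeline: load, split, embed, store. A condensed sketch of that flow, assuming placeholder connection details and an already-loaded ONNX model; the table and column names mirror the test and can be chosen freely:

import dev.langchain4j.data.document.Document;
import dev.langchain4j.data.document.loader.oracle.OracleDocumentLoader;
import dev.langchain4j.data.document.splitter.oracle.OracleDocumentSplitter;
import dev.langchain4j.model.oracle.OracleEmbeddingModel;
import dev.langchain4j.store.embedding.EmbeddingStoreIngestor;
import dev.langchain4j.store.embedding.oracle.EmbeddingTable;
import dev.langchain4j.store.embedding.oracle.OracleEmbeddingStore;
import static dev.langchain4j.store.embedding.oracle.CreateOption.CREATE_OR_REPLACE;

import java.sql.Connection;
import java.util.List;

public class IngestSketch {

    public static void main(String[] args) throws Exception {
        oracle.jdbc.datasource.OracleDataSource ods =
                new oracle.jdbc.datasource.impl.OracleDataSource();
        ods.setURL("jdbc:oracle:thin:@//host:1521/service"); // placeholder
        ods.setUser("user");
        ods.setPassword("password");
        Connection conn = ods.getConnection();

        OracleDocumentLoader loader = new OracleDocumentLoader(conn);
        OracleDocumentSplitter splitter =
                new OracleDocumentSplitter(conn, "{\"by\": \"chars\", \"max\": 50}");
        OracleEmbeddingModel embedder = new OracleEmbeddingModel(
                conn, "{\"provider\": \"database\", \"model\": \"demo_model\"}"); // placeholder model name

        // Creates (or replaces) the target table with the configured column names.
        OracleEmbeddingStore store = OracleEmbeddingStore.builder()
                .dataSource(ods)
                .embeddingTable(EmbeddingTable.builder()
                        .createOption(CREATE_OR_REPLACE)
                        .name("TEST")
                        .idColumn("ID")
                        .embeddingColumn("EMBEDDING")
                        .textColumn("TEXT")
                        .metadataColumn("METADATA")
                        .build())
                .build();

        List<Document> docs = loader.loadDocuments("{\"file\": \"/path/to/sample.pdf\"}");

        EmbeddingStoreIngestor.builder()
                .documentSplitter(splitter)
                .embeddingModel(embedder)
                .embeddingStore(store)
                .build()
                .ingest(docs);
    }
}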