diff --git a/langchain4j-ollama-spring-boot-starter/pom.xml b/langchain4j-ollama-spring-boot-starter/pom.xml
new file mode 100644
index 00000000..718b7c61
--- /dev/null
+++ b/langchain4j-ollama-spring-boot-starter/pom.xml
@@ -0,0 +1,80 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<project xmlns="http://maven.apache.org/POM/4.0.0"
+         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
+         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
+    <modelVersion>4.0.0</modelVersion>
+
+    <parent>
+        <groupId>dev.langchain4j</groupId>
+        <artifactId>langchain4j-spring</artifactId>
+        <version>0.26.0-SNAPSHOT</version>
+        <relativePath>../pom.xml</relativePath>
+    </parent>
+
+    <artifactId>langchain4j-ollama-spring-boot-starter</artifactId>
+    <name>LangChain4j Spring Boot starter for Ollama</name>
+    <packaging>jar</packaging>
+
+    <dependencies>
+
+        <dependency>
+            <groupId>dev.langchain4j</groupId>
+            <artifactId>langchain4j-ollama</artifactId>
+            <version>${project.version}</version>
+        </dependency>
+
+        <dependency>
+            <groupId>org.springframework.boot</groupId>
+            <artifactId>spring-boot-starter</artifactId>
+        </dependency>
+
+        <dependency>
+            <groupId>org.springframework.boot</groupId>
+            <artifactId>spring-boot-autoconfigure-processor</artifactId>
+            <optional>true</optional>
+        </dependency>
+
+        <!-- should be listed before spring-boot-configuration-processor -->
+        <dependency>
+            <groupId>org.projectlombok</groupId>
+            <artifactId>lombok</artifactId>
+            <scope>provided</scope>
+        </dependency>
+
+        <!-- needed to generate automatic metadata about available config properties -->
+        <dependency>
+            <groupId>org.springframework.boot</groupId>
+            <artifactId>spring-boot-configuration-processor</artifactId>
+            <optional>true</optional>
+        </dependency>
+
+        <dependency>
+            <groupId>org.springframework.boot</groupId>
+            <artifactId>spring-boot-starter-test</artifactId>
+            <scope>test</scope>
+        </dependency>
+
+        <dependency>
+            <groupId>org.testcontainers</groupId>
+            <artifactId>testcontainers</artifactId>
+            <scope>test</scope>
+        </dependency>
+
+        <dependency>
+            <groupId>org.testcontainers</groupId>
+            <artifactId>junit-jupiter</artifactId>
+            <scope>test</scope>
+        </dependency>
+
+    </dependencies>
+
+    <licenses>
+        <license>
+            <name>Apache-2.0</name>
+            <url>https://www.apache.org/licenses/LICENSE-2.0.txt</url>
+            <distribution>repo</distribution>
+            <comments>A business-friendly OSS license</comments>
+        </license>
+    </licenses>
+
+</project>
\ No newline at end of file
diff --git a/langchain4j-ollama-spring-boot-starter/src/main/java/dev/langchain4j/ollama/spring/AutoConfig.java b/langchain4j-ollama-spring-boot-starter/src/main/java/dev/langchain4j/ollama/spring/AutoConfig.java
new file mode 100644
index 00000000..dff33a5d
--- /dev/null
+++ b/langchain4j-ollama-spring-boot-starter/src/main/java/dev/langchain4j/ollama/spring/AutoConfig.java
@@ -0,0 +1,115 @@
+package dev.langchain4j.ollama.spring;
+
+import dev.langchain4j.model.ollama.*;
+import org.springframework.boot.autoconfigure.AutoConfiguration;
+import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
+import org.springframework.boot.context.properties.EnableConfigurationProperties;
+import org.springframework.context.annotation.Bean;
+
+import static dev.langchain4j.ollama.spring.Properties.PREFIX;
+
+/**
+ * Spring Boot auto-configuration that exposes Ollama models as beans.
+ * <p>
+ * Each bean is guarded by a {@code langchain4j.ollama.<model>.base-url} property condition;
+ * the remaining settings are copied as-is from {@link Properties} into the model builder.
+ */
+@AutoConfiguration
+@EnableConfigurationProperties(Properties.class)
+public class AutoConfig {
+
+    /** Creates the blocking chat model from the {@code chat-model} settings. */
+    @Bean
+    @ConditionalOnProperty(PREFIX + ".chat-model.base-url")
+    OllamaChatModel ollamaChatModel(Properties properties) {
+        ChatModelProperties chatModelProperties = properties.getChatModel();
+        return OllamaChatModel.builder()
+                .baseUrl(chatModelProperties.getBaseUrl())
+                .modelName(chatModelProperties.getModelName())
+                .temperature(chatModelProperties.getTemperature())
+                .topK(chatModelProperties.getTopK())
+                .topP(chatModelProperties.getTopP())
+                .repeatPenalty(chatModelProperties.getRepeatPenalty())
+                .seed(chatModelProperties.getSeed())
+                .numPredict(chatModelProperties.getNumPredict())
+                .stop(chatModelProperties.getStop())
+                .format(chatModelProperties.getFormat())
+                .timeout(chatModelProperties.getTimeout())
+                .maxRetries(chatModelProperties.getMaxRetries())
+                .build();
+    }
+
+    /** Creates the streaming chat model from the {@code streaming-chat-model} settings. */
+    @Bean
+    @ConditionalOnProperty(PREFIX + ".streaming-chat-model.base-url")
+    OllamaStreamingChatModel ollamaStreamingChatModel(Properties properties) {
+        ChatModelProperties chatModelProperties = properties.getStreamingChatModel();
+        return OllamaStreamingChatModel.builder()
+                .baseUrl(chatModelProperties.getBaseUrl())
+                .modelName(chatModelProperties.getModelName())
+                .temperature(chatModelProperties.getTemperature())
+                .topK(chatModelProperties.getTopK())
+                .topP(chatModelProperties.getTopP())
+                .repeatPenalty(chatModelProperties.getRepeatPenalty())
+                .seed(chatModelProperties.getSeed())
+                .numPredict(chatModelProperties.getNumPredict())
+                .stop(chatModelProperties.getStop())
+                .format(chatModelProperties.getFormat())
+                .timeout(chatModelProperties.getTimeout())
+                .build();
+    }
+
+    /** Creates the blocking language model from the {@code language-model} settings. */
+    @Bean
+    @ConditionalOnProperty(PREFIX + ".language-model.base-url")
+    OllamaLanguageModel ollamaLanguageModel(Properties properties) {
+        LanguageModelProperties languageModelProperties = properties.getLanguageModel();
+        return OllamaLanguageModel.builder()
+                .baseUrl(languageModelProperties.getBaseUrl())
+                .modelName(languageModelProperties.getModelName())
+                .temperature(languageModelProperties.getTemperature())
+                .topK(languageModelProperties.getTopK())
+                .topP(languageModelProperties.getTopP())
+                .repeatPenalty(languageModelProperties.getRepeatPenalty())
+                .seed(languageModelProperties.getSeed())
+                .numPredict(languageModelProperties.getNumPredict())
+                .stop(languageModelProperties.getStop())
+                .format(languageModelProperties.getFormat())
+                .timeout(languageModelProperties.getTimeout())
+                .maxRetries(languageModelProperties.getMaxRetries())
+                .build();
+    }
+
+    /** Creates the streaming language model from the {@code streaming-language-model} settings. */
+    @Bean
+    @ConditionalOnProperty(PREFIX + ".streaming-language-model.base-url")
+    OllamaStreamingLanguageModel ollamaStreamingLanguageModel(Properties properties) {
+        LanguageModelProperties languageModelProperties = properties.getStreamingLanguageModel();
+        return OllamaStreamingLanguageModel.builder()
+                .baseUrl(languageModelProperties.getBaseUrl())
+                .modelName(languageModelProperties.getModelName())
+                .temperature(languageModelProperties.getTemperature())
+                .topK(languageModelProperties.getTopK())
+                .topP(languageModelProperties.getTopP())
+                .repeatPenalty(languageModelProperties.getRepeatPenalty())
+                .seed(languageModelProperties.getSeed())
+                .numPredict(languageModelProperties.getNumPredict())
+                .stop(languageModelProperties.getStop())
+                .format(languageModelProperties.getFormat())
+                .timeout(languageModelProperties.getTimeout())
+                .build();
+    }
+
+    /** Creates the embedding model from the {@code embedding-model} settings. */
+    @Bean
+    @ConditionalOnProperty(PREFIX + ".embedding-model.base-url")
+    OllamaEmbeddingModel ollamaEmbeddingModel(Properties properties) {
+        EmbeddingModelProperties embeddingModelProperties = properties.getEmbeddingModel();
+        return OllamaEmbeddingModel.builder()
+                .baseUrl(embeddingModelProperties.getBaseUrl())
+                .modelName(embeddingModelProperties.getModelName())
+                .timeout(embeddingModelProperties.getTimeout())
+                .maxRetries(embeddingModelProperties.getMaxRetries())
+                .build();
+    }
+}
\ No newline at end of file
diff --git
a/langchain4j-ollama-spring-boot-starter/src/main/java/dev/langchain4j/ollama/spring/ChatModelProperties.java b/langchain4j-ollama-spring-boot-starter/src/main/java/dev/langchain4j/ollama/spring/ChatModelProperties.java
new file mode 100644
index 00000000..4aed4397
--- /dev/null
+++ b/langchain4j-ollama-spring-boot-starter/src/main/java/dev/langchain4j/ollama/spring/ChatModelProperties.java
@@ -0,0 +1,28 @@
+package dev.langchain4j.ollama.spring;
+
+import lombok.Getter;
+import lombok.Setter;
+
+import java.time.Duration;
+import java.util.List;
+
+/**
+ * Bindable settings for an Ollama chat model; a field that is not configured stays {@code null}.
+ */
+@Getter
+@Setter
+class ChatModelProperties {
+
+    String baseUrl;
+    String modelName;
+    Double temperature;
+    Integer topK;
+    Double topP;
+    Double repeatPenalty;
+    Integer seed;
+    Integer numPredict;
+    List<String> stop;
+    String format;
+    Duration timeout;
+    Integer maxRetries;
+}
\ No newline at end of file
diff --git a/langchain4j-ollama-spring-boot-starter/src/main/java/dev/langchain4j/ollama/spring/EmbeddingModelProperties.java b/langchain4j-ollama-spring-boot-starter/src/main/java/dev/langchain4j/ollama/spring/EmbeddingModelProperties.java
new file mode 100644
index 00000000..512ae78e
--- /dev/null
+++ b/langchain4j-ollama-spring-boot-starter/src/main/java/dev/langchain4j/ollama/spring/EmbeddingModelProperties.java
@@ -0,0 +1,19 @@
+package dev.langchain4j.ollama.spring;
+
+import lombok.Getter;
+import lombok.Setter;
+
+import java.time.Duration;
+
+/**
+ * Bindable settings for an Ollama embedding model; a field that is not configured stays {@code null}.
+ */
+@Getter
+@Setter
+class EmbeddingModelProperties {
+
+    String baseUrl;
+    String modelName;
+    Duration timeout;
+    Integer maxRetries;
+}
\ No newline at end of file
diff --git a/langchain4j-ollama-spring-boot-starter/src/main/java/dev/langchain4j/ollama/spring/LanguageModelProperties.java b/langchain4j-ollama-spring-boot-starter/src/main/java/dev/langchain4j/ollama/spring/LanguageModelProperties.java
new file mode 100644
index 00000000..6346005c
--- /dev/null
+++ b/langchain4j-ollama-spring-boot-starter/src/main/java/dev/langchain4j/ollama/spring/LanguageModelProperties.java
@@ -0,0 +1,28 @@
+package dev.langchain4j.ollama.spring;
+
+import lombok.Getter;
+import lombok.Setter;
+
+import java.time.Duration;
+import java.util.List;
+
+/**
+ * Bindable settings for an Ollama language model; a field that is not configured stays {@code null}.
+ */
+@Getter
+@Setter
+class LanguageModelProperties {
+
+    String baseUrl;
+    String modelName;
+    Double temperature;
+    Integer topK;
+    Double topP;
+    Double repeatPenalty;
+    Integer seed;
+    Integer numPredict;
+    List<String> stop;
+    String format;
+    Duration timeout;
+    Integer maxRetries;
+}
\ No newline at end of file
diff --git a/langchain4j-ollama-spring-boot-starter/src/main/java/dev/langchain4j/ollama/spring/Properties.java b/langchain4j-ollama-spring-boot-starter/src/main/java/dev/langchain4j/ollama/spring/Properties.java
new file mode 100644
index 00000000..d122226f
--- /dev/null
+++ b/langchain4j-ollama-spring-boot-starter/src/main/java/dev/langchain4j/ollama/spring/Properties.java
@@ -0,0 +1,33 @@
+package dev.langchain4j.ollama.spring;
+
+import lombok.Getter;
+import lombok.Setter;
+import org.springframework.boot.context.properties.ConfigurationProperties;
+import org.springframework.boot.context.properties.NestedConfigurationProperty;
+
+/**
+ * Root of the {@code langchain4j.ollama} configuration tree, with one nested section per model type.
+ */
+@Getter
+@Setter
+@ConfigurationProperties(prefix = Properties.PREFIX)
+public class Properties {
+
+    /** Prefix shared by every property declared in this class. */
+    static final String PREFIX = "langchain4j.ollama";
+
+    @NestedConfigurationProperty
+    ChatModelProperties chatModel;
+
+    @NestedConfigurationProperty
+    ChatModelProperties streamingChatModel;
+
+    @NestedConfigurationProperty
+    LanguageModelProperties languageModel;
+
+    @NestedConfigurationProperty
+    LanguageModelProperties streamingLanguageModel;
+
+    @NestedConfigurationProperty
+    EmbeddingModelProperties embeddingModel;
+}
diff --git a/langchain4j-ollama-spring-boot-starter/src/main/resources/META-INF/spring.factories b/langchain4j-ollama-spring-boot-starter/src/main/resources/META-INF/spring.factories
new file mode 100644
index 00000000..f4d12b60
--- /dev/null
+++ b/langchain4j-ollama-spring-boot-starter/src/main/resources/META-INF/spring.factories
@@ -0,0 +1 @@
+org.springframework.boot.autoconfigure.EnableAutoConfiguration=dev.langchain4j.ollama.spring.AutoConfig
\ No newline at end of file
diff --git a/langchain4j-ollama-spring-boot-starter/src/main/resources/META-INF/spring/org.springframework.boot.autoconfigure.AutoConfiguration.imports b/langchain4j-ollama-spring-boot-starter/src/main/resources/META-INF/spring/org.springframework.boot.autoconfigure.AutoConfiguration.imports
new file mode 100644
index 00000000..2aca9ab2
--- /dev/null
+++ b/langchain4j-ollama-spring-boot-starter/src/main/resources/META-INF/spring/org.springframework.boot.autoconfigure.AutoConfiguration.imports
@@ -0,0 +1 @@
+dev.langchain4j.ollama.spring.AutoConfig
\ No newline at end of file
diff --git a/langchain4j-ollama-spring-boot-starter/src/test/java/dev/langchain4j/ollama/spring/AutoConfigIT.java b/langchain4j-ollama-spring-boot-starter/src/test/java/dev/langchain4j/ollama/spring/AutoConfigIT.java
new file mode 100644
index 00000000..5a8fd250
--- /dev/null
+++ b/langchain4j-ollama-spring-boot-starter/src/test/java/dev/langchain4j/ollama/spring/AutoConfigIT.java
@@ -0,0 +1,175 @@
+package dev.langchain4j.ollama.spring;
+
+import dev.langchain4j.data.message.AiMessage;
+import dev.langchain4j.model.StreamingResponseHandler;
+import dev.langchain4j.model.chat.ChatLanguageModel;
+import dev.langchain4j.model.chat.StreamingChatLanguageModel;
+import dev.langchain4j.model.embedding.EmbeddingModel;
+import dev.langchain4j.model.language.LanguageModel;
+import dev.langchain4j.model.language.StreamingLanguageModel;
+import dev.langchain4j.model.ollama.*;
+import dev.langchain4j.model.output.Response;
+import org.junit.jupiter.api.Test;
+import org.springframework.boot.autoconfigure.AutoConfigurations;
+import org.springframework.boot.test.context.runner.ApplicationContextRunner;
+import org.testcontainers.containers.GenericContainer;
+import org.testcontainers.junit.jupiter.Container;
+import org.testcontainers.junit.jupiter.Testcontainers;
+
+import java.util.concurrent.CompletableFuture;
+
+import static java.lang.String.format;
+import static java.util.concurrent.TimeUnit.SECONDS;
+import static org.assertj.core.api.Assertions.assertThat;
+
+/**
+ * Checks that {@link AutoConfig} registers each Ollama model bean once its {@code base-url}
+ * property is supplied, running the models against a containerised Ollama instance.
+ */
+@Testcontainers
+class AutoConfigIT {
+
+    private static final String MODEL_NAME = "phi";
+
+    @Container
+    static GenericContainer<?> ollama = new GenericContainer<>("langchain4j/ollama-" + MODEL_NAME)
+            .withExposedPorts(11434);
+
+    ApplicationContextRunner contextRunner = new ApplicationContextRunner()
+            .withConfiguration(AutoConfigurations.of(AutoConfig.class));
+
+    private static String baseUrl() {
+        return format("http://%s:%s", ollama.getHost(), ollama.getFirstMappedPort());
+    }
+
+    @Test
+    void should_provide_chat_model() {
+        contextRunner
+                .withPropertyValues(
+                        "langchain4j.ollama.chat-model.base-url=" + baseUrl(),
+                        "langchain4j.ollama.chat-model.model-name=" + MODEL_NAME,
+                        "langchain4j.ollama.chat-model.num-predict=20",
+                        "langchain4j.ollama.chat-model.temperature=0.0"
+                )
+                .run(context -> {
+
+                    ChatLanguageModel chatLanguageModel = context.getBean(ChatLanguageModel.class);
+                    assertThat(chatLanguageModel).isInstanceOf(OllamaChatModel.class);
+                    assertThat(chatLanguageModel.generate("What is the capital of Germany?")).contains("Berlin");
+
+                    assertThat(context.getBean(OllamaChatModel.class)).isSameAs(chatLanguageModel);
+                });
+    }
+
+    @Test
+    void should_provide_streaming_chat_model() {
+        contextRunner
+                .withPropertyValues(
+                        "langchain4j.ollama.streaming-chat-model.base-url=" + baseUrl(),
+                        "langchain4j.ollama.streaming-chat-model.model-name=" + MODEL_NAME,
+                        "langchain4j.ollama.streaming-chat-model.num-predict=20",
+                        "langchain4j.ollama.streaming-chat-model.temperature=0.0"
+                )
+                .run(context -> {
+
+                    StreamingChatLanguageModel streamingChatLanguageModel = context.getBean(StreamingChatLanguageModel.class);
+                    assertThat(streamingChatLanguageModel).isInstanceOf(OllamaStreamingChatModel.class);
+                    CompletableFuture<Response<AiMessage>> future = new CompletableFuture<>();
+                    streamingChatLanguageModel.generate("What is the capital of Germany?", new StreamingResponseHandler<AiMessage>() {
+
+                        @Override
+                        public void onNext(String token) {
+                        }
+
+                        @Override
+                        public void onComplete(Response<AiMessage> response) {
+                            future.complete(response);
+                        }
+
+                        @Override
+                        public void onError(Throwable error) {
+                            // Surface the failure immediately instead of waiting for the timeout below.
+                            future.completeExceptionally(error);
+                        }
+                    });
+                    Response<AiMessage> response = future.get(30, SECONDS);
+                    assertThat(response.content().text()).contains("Berlin");
+
+                    assertThat(context.getBean(OllamaStreamingChatModel.class)).isSameAs(streamingChatLanguageModel);
+                });
+    }
+
+    @Test
+    void should_provide_language_model() {
+        contextRunner
+                .withPropertyValues(
+                        "langchain4j.ollama.language-model.base-url=" + baseUrl(),
+                        "langchain4j.ollama.language-model.model-name=" + MODEL_NAME,
+                        "langchain4j.ollama.language-model.num-predict=20",
+                        "langchain4j.ollama.language-model.temperature=0.0"
+                )
+                .run(context -> {
+
+                    LanguageModel languageModel = context.getBean(LanguageModel.class);
+                    assertThat(languageModel).isInstanceOf(OllamaLanguageModel.class);
+                    assertThat(languageModel.generate("What is the capital of Germany?").content()).contains("Berlin");
+
+                    assertThat(context.getBean(OllamaLanguageModel.class)).isSameAs(languageModel);
+                });
+    }
+
+    @Test
+    void should_provide_streaming_language_model() {
+        contextRunner
+                .withPropertyValues(
+                        "langchain4j.ollama.streaming-language-model.base-url=" + baseUrl(),
+                        "langchain4j.ollama.streaming-language-model.model-name=" + MODEL_NAME,
+                        "langchain4j.ollama.streaming-language-model.num-predict=20",
+                        "langchain4j.ollama.streaming-language-model.temperature=0.0"
+                )
+                .run(context -> {
+
+                    StreamingLanguageModel streamingLanguageModel = context.getBean(StreamingLanguageModel.class);
+                    assertThat(streamingLanguageModel).isInstanceOf(OllamaStreamingLanguageModel.class);
+                    CompletableFuture<Response<String>> future = new CompletableFuture<>();
+                    streamingLanguageModel.generate("What is the capital of Germany?", new StreamingResponseHandler<String>() {
+
+                        @Override
+                        public void onNext(String token) {
+                        }
+
+                        @Override
+                        public void onComplete(Response<String> response) {
+                            future.complete(response);
+                        }
+
+                        @Override
+                        public void onError(Throwable error) {
+                            // Surface the failure immediately instead of waiting for the timeout below.
+                            future.completeExceptionally(error);
+                        }
+                    });
+                    Response<String> response = future.get(30, SECONDS);
+                    assertThat(response.content()).contains("Berlin");
+
+                    assertThat(context.getBean(OllamaStreamingLanguageModel.class)).isSameAs(streamingLanguageModel);
+                });
+    }
+
+    @Test
+    void should_provide_embedding_model() {
+        contextRunner
+                .withPropertyValues(
+                        "langchain4j.ollama.embedding-model.base-url=" + baseUrl(),
+                        "langchain4j.ollama.embedding-model.model-name=" + MODEL_NAME
+                )
+                .run(context -> {
+
+                    EmbeddingModel embeddingModel = context.getBean(EmbeddingModel.class);
+                    assertThat(embeddingModel).isInstanceOf(OllamaEmbeddingModel.class);
+                    assertThat(embeddingModel.embed("hi").content().dimension()).isEqualTo(2560);
+
+                    assertThat(context.getBean(OllamaEmbeddingModel.class)).isSameAs(embeddingModel);
+                });
+    }
+}
\ No newline at end of file
diff --git a/pom.xml b/pom.xml
index 998c6c6f..67d85813 100644
--- a/pom.xml
+++ b/pom.xml
@@ -14,6 +14,7 @@
     <url>https://github.com/langchain4j/langchain4j-spring</url>
 
     <modules>
+        <module>langchain4j-ollama-spring-boot-starter</module>
         <module>langchain4j-open-ai-spring-boot-starter</module>
     </modules>
 
@@ -57,6 +58,14 @@
                 <version>1.18.30</version>
             </dependency>
 
+            <dependency>
+                <groupId>org.testcontainers</groupId>
+                <artifactId>testcontainers-bom</artifactId>
+                <version>1.19.2</version>
+                <scope>import</scope>
+                <type>pom</type>
+            </dependency>
+
         </dependencies>
     </dependencyManagement>
 