diff --git a/.idea/gradle.xml b/.idea/gradle.xml index 3bba60c..056658a 100644 --- a/.idea/gradle.xml +++ b/.idea/gradle.xml @@ -14,7 +14,9 @@ diff --git a/docs-site/docs/03_components/03_text_generation.md b/docs-site/docs/03_components/03_text_generation.md index 747f88b..53aa1f1 100644 --- a/docs-site/docs/03_components/03_text_generation.md +++ b/docs-site/docs/03_components/03_text_generation.md @@ -1,27 +1,28 @@ # Text generation module -The text generation module aims to give the user an abstraction to use any model to give final answers to user questions. +The text generation module provides an abstraction to use Large Language Models (LLMs) that, given a prompt, complete it with the most useful text possible. LLMs can be used for a very wide range of use cases, including summarizing text, drafting an essay, or getting answers for general knowledge questions. -Below, you can find the current implementations available. +eLLMental provides the following implementations: -## OpenAI implementation +## `OpenAiChatGenerationModel` -You can use the `OpenAiTextGenerationService` to work with openAI chat messaging system: +The `OpenAiChatGenerationModel` is pre-configured to simulate a chat conversation. In this case, the model accepts a list of `ChatMessage` objects, where each message has a `role` and a `content` field. Then, the model will generate a response to the last message in the list. If you have used [ChatGPT](https://openai.com/chatgpt) before, the behavior will be very similar. ```java -Double temperature = 0.1; -int maxTokens = 3000; ChatMessage chatMessage = new ChatMessage("user", "what is an LLM?"); -TextGenerationService, OpenAiModels> service = new OpenAiTextGenerationService("api key"); -String result = service.generate(List.of(chatMessage), temperature, maxTokens, OpenAiModels.GPT_3_5); +// Using the simplified constructor, the library will set sensible defaults for the temperature, maxTokens and model. 
+OpenAiChatGenerationModel model = new OpenAiChatGenerationModel(""); +String response = model.generate(List.of(chatMessage)); ``` -The only requisite to have your openAI types ready to use is to add the openAI dependency to your `build.gradle` file: +`OpenAiChatGenerationModel` is pre-configured to use the GPT-3.5 model by default, but if you have API access to more advanced models, you can specify them, as well as the temperature and maxTokens, using the alternate constructors: -```gradle -dependencies { - implementation 'com.theokanning.openai-gpt3-java:service:0.16.0' - ... -} +```java +// There's a convenience constructor that allows setting the model exclusively +OpenAiChatGenerationModel modelWithGPT4 = new OpenAiChatGenerationModel("", OpenAiModels.GPT_4); +OpenAiChatGenerationModel modelWithGPT316K = new OpenAiChatGenerationModel("", OpenAiModels.GPT_3_5_CONTEXT_16K); + +// Or you can use the full constructor to set the model, temperature and maxTokens +OpenAiChatGenerationModel customModel = new OpenAiChatGenerationModel("", OpenAiModels.GPT_3_5_CONTEXT_16K, 0.5, 100); ``` diff --git a/modules/textgeneration/src/main/java/com/theagilemonkeys/ellmental/textgeneration/TextGenerationService.java b/modules/textgeneration/src/main/java/com/theagilemonkeys/ellmental/textgeneration/TextGenerationService.java index a1397f3..a90f3d1 100644 --- a/modules/textgeneration/src/main/java/com/theagilemonkeys/ellmental/textgeneration/TextGenerationService.java +++ b/modules/textgeneration/src/main/java/com/theagilemonkeys/ellmental/textgeneration/TextGenerationService.java @@ -1,5 +1,7 @@ package com.theagilemonkeys.ellmental.textgeneration; -public abstract class TextGenerationService { - public abstract String generate(Input input, Double temperature, int maxTokens, ProviderParameters parameters); +import java.util.List; + +public abstract class TextGenerationService<T> { + public abstract String generate(List<T> input); } diff --git 
a/modules/textgeneration/src/main/java/com/theagilemonkeys/ellmental/textgeneration/openai/OpenAiChatGenerationModel.java b/modules/textgeneration/src/main/java/com/theagilemonkeys/ellmental/textgeneration/openai/OpenAiChatGenerationModel.java new file mode 100644 index 0000000..7f419cb --- /dev/null +++ b/modules/textgeneration/src/main/java/com/theagilemonkeys/ellmental/textgeneration/openai/OpenAiChatGenerationModel.java @@ -0,0 +1,83 @@ +package com.theagilemonkeys.ellmental.textgeneration.openai; + +import com.theagilemonkeys.ellmental.textgeneration.TextGenerationService; +import com.theokanning.openai.completion.chat.ChatCompletionChoice; +import com.theokanning.openai.completion.chat.ChatCompletionRequest; +import com.theokanning.openai.completion.chat.ChatMessage; +import com.theokanning.openai.service.OpenAiService; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.time.Duration; +import java.util.List; + +import static com.theagilemonkeys.ellmental.textgeneration.openai.errors.Constants.NO_CONTENT_FOUND_OPEN_AI; + +public class OpenAiChatGenerationModel extends TextGenerationService<ChatMessage> { + private static final Double DEFAULT_TEMPERATURE = 0.7; + private static final int DEFAULT_MAX_TOKENS = 3000; + private static final Logger log = LoggerFactory.getLogger(OpenAiChatGenerationModel.class); + private final Double temperature; + private final int maxTokens; + private final OpenAiModels model; + // The OpenAI client is package-private to allow injecting a mock in tests + OpenAiService openAiService; + + /** + * Constructor for OpenAiChatGenerationModel that uses default values for temperature (0.7), + * maxTokens (3000) and model (GPT-3.5). 
+ * @param openAiKey OpenAI API key + */ + public OpenAiChatGenerationModel(String openAiKey) { + this(openAiKey, OpenAiModels.GPT_3_5, DEFAULT_TEMPERATURE, DEFAULT_MAX_TOKENS); + } + + /** + * Constructor for OpenAiChatGenerationModel that explicitly sets the model, but uses the default values for + * temperature (0.7) and maxTokens (3000). + * @param openAiKey OpenAI API key + * @param model Model to use for the chat generation + */ + public OpenAiChatGenerationModel(String openAiKey, OpenAiModels model) { + this(openAiKey, model, DEFAULT_TEMPERATURE, DEFAULT_MAX_TOKENS); + } + + /** + * Constructor for OpenAiChatGenerationModel that explicitly sets the temperature, maxTokens and model. + * @param openAiKey OpenAI API key + * @param model Model to use for the chat generation + * @param temperature Temperature to use for the chat generation + * @param maxTokens Maximum number of tokens to use for the chat generation + */ + public OpenAiChatGenerationModel(String openAiKey, OpenAiModels model, Double temperature, int maxTokens) { + this.openAiService = new OpenAiService(openAiKey, Duration.ofSeconds(240)); + this.temperature = temperature; + this.maxTokens = maxTokens; + this.model = model; + } + + /** + * Generates a chat response using the OpenAI API. 
+ * @param chatMessages List of chat messages to use as context for the response + * @return Generated chat response + */ + @Override + public String generate(List<ChatMessage> chatMessages) { + log.debug("Generating chat response for chat messages {}", chatMessages); + ChatCompletionRequest chatCompletionRequest = ChatCompletionRequest.builder() + .messages(chatMessages) + .temperature(temperature) + .maxTokens(maxTokens) + .model(model.getCodename()) + .build(); + + ChatCompletionChoice chatCompletionChoice = openAiService.createChatCompletion(chatCompletionRequest).getChoices().get(0); + String chatCompletionContent = chatCompletionChoice.getMessage().getContent(); + + log.debug("Chat completion response is {}", chatCompletionContent); + + return !chatCompletionContent.isEmpty() ? + chatCompletionContent : + String.format(NO_CONTENT_FOUND_OPEN_AI, chatMessages); + } +} diff --git a/modules/textgeneration/src/main/java/com/theagilemonkeys/ellmental/textgeneration/openai/OpenAiTextGenerationService.java b/modules/textgeneration/src/main/java/com/theagilemonkeys/ellmental/textgeneration/openai/OpenAiTextGenerationService.java deleted file mode 100644 index a16ffb1..0000000 --- a/modules/textgeneration/src/main/java/com/theagilemonkeys/ellmental/textgeneration/openai/OpenAiTextGenerationService.java +++ /dev/null @@ -1,48 +0,0 @@ -package com.theagilemonkeys.ellmental.textgeneration.openai; - -import com.theagilemonkeys.ellmental.textgeneration.TextGenerationService; -import com.theokanning.openai.completion.chat.ChatCompletionChoice; -import com.theokanning.openai.completion.chat.ChatCompletionRequest; -import com.theokanning.openai.completion.chat.ChatMessage; -import com.theokanning.openai.service.OpenAiService; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import java.time.Duration; -import java.util.List; - -import static com.theagilemonkeys.ellmental.textgeneration.openai.errors.Constants.NO_CONTENT_FOUND_OPEN_AI; - -public class OpenAiTextGenerationService 
extends TextGenerationService, OpenAiModels> { - private static final Logger log = LoggerFactory.getLogger(OpenAiTextGenerationService.class); - private final OpenAiService client; - - public OpenAiTextGenerationService(String openAiKey) { - this.client = new OpenAiService(openAiKey, Duration.ofSeconds(240)); - } - - // Only visible for testing purposes - OpenAiTextGenerationService(OpenAiService client) { - this.client = client; - } - - @Override - public String generate(List chatMessages, Double temperature, int maxTokens, OpenAiModels model) { - log.debug("Generating chat response for chat messages {}", chatMessages); - ChatCompletionRequest chatCompletionRequest = ChatCompletionRequest.builder() - .messages(chatMessages) - .temperature(temperature) - .maxTokens(maxTokens) - .model(model.getCodename()) - .build(); - - ChatCompletionChoice chatCompletionChoice = client.createChatCompletion(chatCompletionRequest).getChoices().get(0); - String chatCompletionContent = chatCompletionChoice.getMessage().getContent(); - - log.debug("Chat completion response is {}", chatCompletionContent); - - return !chatCompletionContent.isEmpty() ? 
- chatCompletionContent : - String.format(NO_CONTENT_FOUND_OPEN_AI, chatMessages); - } -} diff --git a/modules/textgeneration/src/test/java/com/theagilemonkeys/ellmental/textgeneration/openai/OpenAiTextGenerationServiceTest.java b/modules/textgeneration/src/test/java/com/theagilemonkeys/ellmental/textgeneration/openai/OpenAiChatGenerationModelTest.java similarity index 83% rename from modules/textgeneration/src/test/java/com/theagilemonkeys/ellmental/textgeneration/openai/OpenAiTextGenerationServiceTest.java rename to modules/textgeneration/src/test/java/com/theagilemonkeys/ellmental/textgeneration/openai/OpenAiChatGenerationModelTest.java index 3e9a180..ce5fd4b 100644 --- a/modules/textgeneration/src/test/java/com/theagilemonkeys/ellmental/textgeneration/openai/OpenAiTextGenerationServiceTest.java +++ b/modules/textgeneration/src/test/java/com/theagilemonkeys/ellmental/textgeneration/openai/OpenAiChatGenerationModelTest.java @@ -5,11 +5,11 @@ import com.theokanning.openai.completion.chat.ChatCompletionResult; import com.theokanning.openai.completion.chat.ChatMessage; import com.theokanning.openai.service.OpenAiService; +import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; import org.mockito.InjectMocks; import org.mockito.Mock; -import org.mockito.Spy; import org.mockito.junit.jupiter.MockitoExtension; import java.util.List; @@ -20,18 +20,22 @@ import static org.mockito.Mockito.when; @ExtendWith(MockitoExtension.class) -public class OpenAiTextGenerationServiceTest { - +public class OpenAiChatGenerationModelTest { @Mock private OpenAiService openAiService; - @InjectMocks - private OpenAiTextGenerationService openAiTextGenerationService; + private OpenAiChatGenerationModel openAiChatGenerationModel; private final List chatMessages = List.of(new ChatMessage("user", "How can I rank up in Rocket League?")); private final Double temperature = 0.1; private final int maxTokens = 3000; private final 
OpenAiModels model = OpenAiModels.GPT_4; + @BeforeEach + public void setUp() { + openAiChatGenerationModel = new OpenAiChatGenerationModel("openAiKey", model, temperature, maxTokens); + openAiChatGenerationModel.openAiService = openAiService; + } + @Test public void testGenerate() { String chatResult = "test content"; @@ -41,7 +45,7 @@ public void testGenerate() { result.setChoices(List.of(chatCompletionChoice)); when(openAiService.createChatCompletion(getChatCompletion())).thenReturn(result); - String generatedText = openAiTextGenerationService.generate(chatMessages, temperature, maxTokens, model); + String generatedText = openAiChatGenerationModel.generate(chatMessages); assertEquals(chatResult, generatedText); @@ -56,7 +60,7 @@ public void testGenerateWhenContentNotFound() { emptyResult.setChoices(List.of(chatCompletionChoice)); when(openAiService.createChatCompletion(getChatCompletion())).thenReturn(emptyResult); - String generatedText = openAiTextGenerationService.generate(chatMessages, temperature, maxTokens, model); + String generatedText = openAiChatGenerationModel.generate(chatMessages); assertEquals(generatedText, String.format(NO_CONTENT_FOUND_OPEN_AI, chatMessages));