diff --git a/.idea/gradle.xml b/.idea/gradle.xml
index 3bba60c..056658a 100644
--- a/.idea/gradle.xml
+++ b/.idea/gradle.xml
@@ -14,7 +14,9 @@
+
+
diff --git a/docs-site/docs/03_components/03_text_generation.md b/docs-site/docs/03_components/03_text_generation.md
index 747f88b..53aa1f1 100644
--- a/docs-site/docs/03_components/03_text_generation.md
+++ b/docs-site/docs/03_components/03_text_generation.md
@@ -1,27 +1,28 @@
# Text generation module
-The text generation module aims to give the user an abstraction to use any model to give final answers to user questions.
+The text generation module provides an abstraction to use Large Language Models (LLMs) that, given a prompt, complete it with the most useful text possible. LLMs can be used for a very wide range of use cases, including summarizing text, drafting an essay, or getting answers for general knowledge questions.
-Below, you can find the current implementations available.
+eLLMental provides the following implementations:
-## OpenAI implementation
+## `OpenAiChatGenerationModel`
-You can use the `OpenAiTextGenerationService` to work with openAI chat messaging system:
+The `OpenAiChatGenerationModel` is pre-configured to simulate a chat conversation. In this case, the model accepts a list of `ChatMessage` objects, where each message has a `role` and a `content` field. Then, the model will generate a response to the last message in the list. If you have used [ChatGPT](https://openai.com/chatgpt) before, the behavior will be very similar.
```java
-Double temperature = 0.1;
-int maxTokens = 3000;
ChatMessage chatMessage = new ChatMessage("user", "what is an LLM?");
-TextGenerationService<List<ChatMessage>, OpenAiModels> service = new OpenAiTextGenerationService("api key");
-String result = service.generate(List.of(chatMessage), temperature, maxTokens, OpenAiModels.GPT_3_5);
+// Using the simplified constructor, the library will set sensible defaults for the temperature, maxTokens and model.
+OpenAiChatGenerationModel model = new OpenAiChatGenerationModel("");
+String response = model.generate(List.of(chatMessage));
```
-The only requisite to have your openAI types ready to use is to add the openAI dependency to your `build.gradle` file:
+`OpenAiChatGenerationModel` is pre-configured to use the GPT-3.5 model by default, but if you have API access to more advanced models, you can specify them, as well as the temperature and maxTokens, using the alternate constructors:
-```gradle
-dependencies {
- implementation 'com.theokanning.openai-gpt3-java:service:0.16.0'
- ...
-}
+```java
+// There's a convenience constructor that allows setting the model exclusively
+OpenAiChatGenerationModel modelWithGPT4 = new OpenAiChatGenerationModel("", OpenAiModels.GPT_4);
+OpenAiChatGenerationModel modelWithGPT316K = new OpenAiChatGenerationModel("", OpenAiModels.GPT_3_5_CONTEXT_16K);
+
+// Or you can use the full constructor to set the model, temperature and maxTokens
+OpenAiChatGenerationModel customModel = new OpenAiChatGenerationModel("", OpenAiModels.GPT_3_5_CONTEXT_16K, 0.5, 100);
```
diff --git a/modules/textgeneration/src/main/java/com/theagilemonkeys/ellmental/textgeneration/TextGenerationService.java b/modules/textgeneration/src/main/java/com/theagilemonkeys/ellmental/textgeneration/TextGenerationService.java
index a1397f3..a90f3d1 100644
--- a/modules/textgeneration/src/main/java/com/theagilemonkeys/ellmental/textgeneration/TextGenerationService.java
+++ b/modules/textgeneration/src/main/java/com/theagilemonkeys/ellmental/textgeneration/TextGenerationService.java
@@ -1,5 +1,7 @@
package com.theagilemonkeys.ellmental.textgeneration;
-public abstract class TextGenerationService<Input, ProviderParameters> {
- public abstract String generate(Input input, Double temperature, int maxTokens, ProviderParameters parameters);
+import java.util.List;
+
+public abstract class TextGenerationService<Input> {
+ public abstract String generate(List<Input> input);
}
diff --git a/modules/textgeneration/src/main/java/com/theagilemonkeys/ellmental/textgeneration/openai/OpenAiChatGenerationModel.java b/modules/textgeneration/src/main/java/com/theagilemonkeys/ellmental/textgeneration/openai/OpenAiChatGenerationModel.java
new file mode 100644
index 0000000..7f419cb
--- /dev/null
+++ b/modules/textgeneration/src/main/java/com/theagilemonkeys/ellmental/textgeneration/openai/OpenAiChatGenerationModel.java
@@ -0,0 +1,83 @@
+package com.theagilemonkeys.ellmental.textgeneration.openai;
+
+import com.theagilemonkeys.ellmental.textgeneration.TextGenerationService;
+import com.theokanning.openai.completion.chat.ChatCompletionChoice;
+import com.theokanning.openai.completion.chat.ChatCompletionRequest;
+import com.theokanning.openai.completion.chat.ChatMessage;
+import com.theokanning.openai.service.OpenAiService;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.time.Duration;
+import java.util.List;
+
+import static com.theagilemonkeys.ellmental.textgeneration.openai.errors.Constants.NO_CONTENT_FOUND_OPEN_AI;
+
+public class OpenAiChatGenerationModel extends TextGenerationService<ChatMessage> {
+ private static final Double DEFAULT_TEMPERATURE = 0.7;
+ private static final int DEFAULT_MAX_TOKENS = 3000;
+ private static final Logger log = LoggerFactory.getLogger(OpenAiChatGenerationModel.class);
+ private final Double temperature;
+ private final int maxTokens;
+ private final OpenAiModels model;
+ // The OpenAI client is package-private to allow injecting a mock in tests
+ OpenAiService openAiService;
+
+ /**
+ * Constructor for OpenAiChatGenerationModel that uses default values for temperature (0.7),
+ * maxTokens (3000) and model (GPT-3.5).
+ * @param openAiKey OpenAI API key
+ */
+ public OpenAiChatGenerationModel(String openAiKey) {
+ this(openAiKey, OpenAiModels.GPT_3_5, DEFAULT_TEMPERATURE, DEFAULT_MAX_TOKENS);
+ }
+
+ /**
+ * Constructor for OpenAiChatGenerationModel that explicitly sets the model, but uses the default values for
+ * temperature (0.7) and maxTokens (3000).
+ * @param openAiKey OpenAI API key
+ * @param model Model to use for the chat generation
+ */
+ public OpenAiChatGenerationModel(String openAiKey, OpenAiModels model) {
+ this(openAiKey, model, DEFAULT_TEMPERATURE, DEFAULT_MAX_TOKENS);
+ }
+
+ /**
+ * Constructor for OpenAiChatGenerationModel that explicitly sets the model, temperature and maxTokens.
+ * @param openAiKey OpenAI API key
+ * @param model Model to use for the chat generation
+ * @param temperature Temperature to use for the chat generation
+ * @param maxTokens Maximum number of tokens to use for the chat generation
+ */
+ public OpenAiChatGenerationModel(String openAiKey, OpenAiModels model, Double temperature, int maxTokens) {
+ this.openAiService = new OpenAiService(openAiKey, Duration.ofSeconds(240));
+ this.temperature = temperature;
+ this.maxTokens = maxTokens;
+ this.model = model;
+ }
+
+ /**
+ * Generates a chat response using the OpenAI API.
+ * @param chatMessages List of chat messages to use as context for the response
+ * @return Generated chat response
+ */
+ @Override
+ public String generate(List<ChatMessage> chatMessages) {
+ log.debug("Generating chat response for chat messages {}", chatMessages);
+ ChatCompletionRequest chatCompletionRequest = ChatCompletionRequest.builder()
+ .messages(chatMessages)
+ .temperature(temperature)
+ .maxTokens(maxTokens)
+ .model(model.getCodename())
+ .build();
+
+ ChatCompletionChoice chatCompletionChoice = openAiService.createChatCompletion(chatCompletionRequest).getChoices().get(0);
+ String chatCompletionContent = chatCompletionChoice.getMessage().getContent();
+
+ log.debug("Chat completion response is {}", chatCompletionContent);
+
+ return !chatCompletionContent.isEmpty() ?
+ chatCompletionContent :
+ String.format(NO_CONTENT_FOUND_OPEN_AI, chatMessages);
+ }
+}
diff --git a/modules/textgeneration/src/main/java/com/theagilemonkeys/ellmental/textgeneration/openai/OpenAiTextGenerationService.java b/modules/textgeneration/src/main/java/com/theagilemonkeys/ellmental/textgeneration/openai/OpenAiTextGenerationService.java
deleted file mode 100644
index a16ffb1..0000000
--- a/modules/textgeneration/src/main/java/com/theagilemonkeys/ellmental/textgeneration/openai/OpenAiTextGenerationService.java
+++ /dev/null
@@ -1,48 +0,0 @@
-package com.theagilemonkeys.ellmental.textgeneration.openai;
-
-import com.theagilemonkeys.ellmental.textgeneration.TextGenerationService;
-import com.theokanning.openai.completion.chat.ChatCompletionChoice;
-import com.theokanning.openai.completion.chat.ChatCompletionRequest;
-import com.theokanning.openai.completion.chat.ChatMessage;
-import com.theokanning.openai.service.OpenAiService;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
-
-import java.time.Duration;
-import java.util.List;
-
-import static com.theagilemonkeys.ellmental.textgeneration.openai.errors.Constants.NO_CONTENT_FOUND_OPEN_AI;
-
-public class OpenAiTextGenerationService extends TextGenerationService<List<ChatMessage>, OpenAiModels> {
- private static final Logger log = LoggerFactory.getLogger(OpenAiTextGenerationService.class);
- private final OpenAiService client;
-
- public OpenAiTextGenerationService(String openAiKey) {
- this.client = new OpenAiService(openAiKey, Duration.ofSeconds(240));
- }
-
- // Only visible for testing purposes
- OpenAiTextGenerationService(OpenAiService client) {
- this.client = client;
- }
-
- @Override
- public String generate(List<ChatMessage> chatMessages, Double temperature, int maxTokens, OpenAiModels model) {
- log.debug("Generating chat response for chat messages {}", chatMessages);
- ChatCompletionRequest chatCompletionRequest = ChatCompletionRequest.builder()
- .messages(chatMessages)
- .temperature(temperature)
- .maxTokens(maxTokens)
- .model(model.getCodename())
- .build();
-
- ChatCompletionChoice chatCompletionChoice = client.createChatCompletion(chatCompletionRequest).getChoices().get(0);
- String chatCompletionContent = chatCompletionChoice.getMessage().getContent();
-
- log.debug("Chat completion response is {}", chatCompletionContent);
-
- return !chatCompletionContent.isEmpty() ?
- chatCompletionContent :
- String.format(NO_CONTENT_FOUND_OPEN_AI, chatMessages);
- }
-}
diff --git a/modules/textgeneration/src/test/java/com/theagilemonkeys/ellmental/textgeneration/openai/OpenAiTextGenerationServiceTest.java b/modules/textgeneration/src/test/java/com/theagilemonkeys/ellmental/textgeneration/openai/OpenAiChatGenerationModelTest.java
similarity index 83%
rename from modules/textgeneration/src/test/java/com/theagilemonkeys/ellmental/textgeneration/openai/OpenAiTextGenerationServiceTest.java
rename to modules/textgeneration/src/test/java/com/theagilemonkeys/ellmental/textgeneration/openai/OpenAiChatGenerationModelTest.java
index 3e9a180..ce5fd4b 100644
--- a/modules/textgeneration/src/test/java/com/theagilemonkeys/ellmental/textgeneration/openai/OpenAiTextGenerationServiceTest.java
+++ b/modules/textgeneration/src/test/java/com/theagilemonkeys/ellmental/textgeneration/openai/OpenAiChatGenerationModelTest.java
@@ -5,11 +5,11 @@
import com.theokanning.openai.completion.chat.ChatCompletionResult;
import com.theokanning.openai.completion.chat.ChatMessage;
import com.theokanning.openai.service.OpenAiService;
+import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.InjectMocks;
import org.mockito.Mock;
-import org.mockito.Spy;
import org.mockito.junit.jupiter.MockitoExtension;
import java.util.List;
@@ -20,18 +20,22 @@
import static org.mockito.Mockito.when;
@ExtendWith(MockitoExtension.class)
-public class OpenAiTextGenerationServiceTest {
-
+public class OpenAiChatGenerationModelTest {
@Mock
private OpenAiService openAiService;
- @InjectMocks
- private OpenAiTextGenerationService openAiTextGenerationService;
+ private OpenAiChatGenerationModel openAiChatGenerationModel;
private final List<ChatMessage> chatMessages = List.of(new ChatMessage("user", "How can I rank up in Rocket League?"));
private final Double temperature = 0.1;
private final int maxTokens = 3000;
private final OpenAiModels model = OpenAiModels.GPT_4;
+ @BeforeEach
+ public void setUp() {
+ openAiChatGenerationModel = new OpenAiChatGenerationModel("openAiKey", model, temperature, maxTokens);
+ openAiChatGenerationModel.openAiService = openAiService;
+ }
+
@Test
public void testGenerate() {
String chatResult = "test content";
@@ -41,7 +45,7 @@ public void testGenerate() {
result.setChoices(List.of(chatCompletionChoice));
when(openAiService.createChatCompletion(getChatCompletion())).thenReturn(result);
- String generatedText = openAiTextGenerationService.generate(chatMessages, temperature, maxTokens, model);
+ String generatedText = openAiChatGenerationModel.generate(chatMessages);
assertEquals(chatResult, generatedText);
@@ -56,7 +60,7 @@ public void testGenerateWhenContentNotFound() {
emptyResult.setChoices(List.of(chatCompletionChoice));
when(openAiService.createChatCompletion(getChatCompletion())).thenReturn(emptyResult);
- String generatedText = openAiTextGenerationService.generate(chatMessages, temperature, maxTokens, model);
+ String generatedText = openAiChatGenerationModel.generate(chatMessages);
assertEquals(generatedText, String.format(NO_CONTENT_FOUND_OPEN_AI, chatMessages));