Interface improvements on the Text Generator implementation #30

Merged
2 changes: 2 additions & 0 deletions .idea/gradle.xml

Some generated files are not rendered by default.

29 changes: 15 additions & 14 deletions docs-site/docs/03_components/03_text_generation.md
@@ -1,27 +1,28 @@
 # Text generation module
 
-The text generation module aims to give the user an abstraction to use any model to give final answers to user questions.
+The text generation module provides an abstraction to use Large Language Models (LLMs) that, given a prompt, complete it with the most useful text possible. LLMs can be used for a very wide range of use cases, including summarizing text, drafting an essay, or getting answers to general knowledge questions.
 
-Below, you can find the current implementations available.
+eLLMental provides the following implementations:
 
-## OpenAI implementation
+## `OpenAiChatGenerationModel`
 
-You can use the `OpenAiTextGenerationService` to work with the openAI chat messaging system:
+The `OpenAiChatGenerationModel` is pre-configured to simulate a chat conversation. The model accepts a list of `ChatMessage` objects, where each message has a `role` and a `content` field, and generates a response to the last message in the list. If you have used [ChatGPT](https://openai.com/chatgpt) before, the behavior will be very similar.
 
 ```java
-Double temperature = 0.1;
-int maxTokens = 3000;
 ChatMessage chatMessage = new ChatMessage("user", "what is an LLM?");
 
-TextGenerationService<List<ChatMessage>, OpenAiModels> service = new OpenAiTextGenerationService("api key");
-String result = service.generate(List.of(chatMessage), temperature, maxTokens, OpenAiModels.GPT_3_5);
+// Using the simplified constructor, the library sets sensible defaults for the temperature, maxTokens and model.
+OpenAiChatGenerationModel model = new OpenAiChatGenerationModel("<api key>");
+String response = model.generate(List.of(chatMessage));
 ```
 
-The only requisite to have your openAI types ready to use is to add the openAI dependency to your `build.gradle` file:
+`OpenAiChatGenerationModel` uses the GPT-3.5 model by default, but if you have API access to more advanced models, you can specify them, as well as the temperature and maxTokens, using the alternate constructors:
 
-```gradle
-dependencies {
-    implementation 'com.theokanning.openai-gpt3-java:service:0.16.0'
-    ...
-}
+```java
+// There's a convenience constructor that allows setting the model exclusively
+OpenAiChatGenerationModel modelWithGPT4 = new OpenAiChatGenerationModel("<api key>", OpenAiModels.GPT_4);
+OpenAiChatGenerationModel modelWithGPT316K = new OpenAiChatGenerationModel("<api key>", OpenAiModels.GPT_3_5_CONTEXT_16K);
+
+// Or you can use the full constructor to set the model, temperature and maxTokens
+OpenAiChatGenerationModel customModel = new OpenAiChatGenerationModel("<api key>", OpenAiModels.GPT_3_5_CONTEXT_16K, 0.5, 100);
 ```
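To make the chat-style API documented above more concrete, here is a standalone sketch of how a multi-turn conversation could be assembled. The `ChatMessage` record below is a minimal stand-in for the real class from the `openai-gpt3-java` dependency (not the library's own type), and the commented-out call shows where the list would be handed to `OpenAiChatGenerationModel`:

```java
import java.util.ArrayList;
import java.util.List;

public class ChatSketch {
    // Minimal stand-in for com.theokanning.openai.completion.chat.ChatMessage
    record ChatMessage(String role, String content) {}

    public static void main(String[] args) {
        List<ChatMessage> conversation = new ArrayList<>();
        conversation.add(new ChatMessage("system", "You are a helpful assistant."));
        conversation.add(new ChatMessage("user", "What is an LLM?"));

        // With the real model this would be: String reply = model.generate(conversation);
        // Here we only show the shape of the payload the model receives.
        System.out.println(conversation.size());
        System.out.println(conversation.get(1).content());
    }
}
```

The model answers the last message in the list, so the newest user turn always goes at the end.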
@@ -1,5 +1,7 @@
 package com.theagilemonkeys.ellmental.textgeneration;
 
-public abstract class TextGenerationService<Input, ProviderParameters> {
-    public abstract String generate(Input input, Double temperature, int maxTokens, ProviderParameters parameters);
+import java.util.List;
+
+public abstract class TextGenerationService<Input> {
+    public abstract String generate(List<Input> input);
 }
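The abstract class above now only asks implementors for a `generate(List<Input>)` method, with provider-specific settings (temperature, maxTokens, model) moving into each implementation's constructor. As a runnable sketch of this contract, here is a hypothetical toy implementation, with the abstract class reproduced inline so the example is self-contained:

```java
import java.util.List;

// Reproduction of the simplified abstract class from this PR
abstract class TextGenerationService<Input> {
    public abstract String generate(List<Input> input);
}

// Hypothetical toy implementation: echoes the inputs back, useful as a test stub
class EchoTextGenerationService extends TextGenerationService<String> {
    @Override
    public String generate(List<String> input) {
        return String.join(" ", input);
    }
}

public class Main {
    public static void main(String[] args) {
        TextGenerationService<String> service = new EchoTextGenerationService();
        System.out.println(service.generate(List.of("Hello,", "world")));
        // prints "Hello, world"
    }
}
```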
@@ -0,0 +1,83 @@
package com.theagilemonkeys.ellmental.textgeneration.openai;

import com.theagilemonkeys.ellmental.textgeneration.TextGenerationService;
import com.theokanning.openai.completion.chat.ChatCompletionChoice;
import com.theokanning.openai.completion.chat.ChatCompletionRequest;
import com.theokanning.openai.completion.chat.ChatMessage;
import com.theokanning.openai.service.OpenAiService;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.time.Duration;
import java.util.List;

import static com.theagilemonkeys.ellmental.textgeneration.openai.errors.Constants.NO_CONTENT_FOUND_OPEN_AI;

public class OpenAiChatGenerationModel extends TextGenerationService<ChatMessage> {
    private static final Double DEFAULT_TEMPERATURE = 0.7;
    private static final int DEFAULT_MAX_TOKENS = 3000;
    private static final Logger log = LoggerFactory.getLogger(OpenAiChatGenerationModel.class);
    private final Double temperature;
    private final int maxTokens;
    private final OpenAiModels model;
    // The OpenAI client is package-private to allow injecting a mock in tests
    OpenAiService openAiService;

    /**
     * Constructor for OpenAiChatGenerationModel that uses default values for temperature (0.7),
     * maxTokens (3000) and model (GPT-3.5).
     * @param openAiKey OpenAI API key
     */
    public OpenAiChatGenerationModel(String openAiKey) {
        this(openAiKey, OpenAiModels.GPT_3_5, DEFAULT_TEMPERATURE, DEFAULT_MAX_TOKENS);
    }

    /**
     * Constructor for OpenAiChatGenerationModel that explicitly sets the model, but uses the default values for
     * temperature (0.7) and maxTokens (3000).
     * @param openAiKey OpenAI API key
     * @param model Model to use for the chat generation
     */
    public OpenAiChatGenerationModel(String openAiKey, OpenAiModels model) {
        this(openAiKey, model, DEFAULT_TEMPERATURE, DEFAULT_MAX_TOKENS);
    }

    /**
     * Constructor for OpenAiChatGenerationModel that explicitly sets the model, temperature and maxTokens.
     * @param openAiKey OpenAI API key
     * @param model Model to use for the chat generation
     * @param temperature Temperature to use for the chat generation
     * @param maxTokens Maximum number of tokens to use for the chat generation
     */
    public OpenAiChatGenerationModel(String openAiKey, OpenAiModels model, Double temperature, int maxTokens) {
        this.openAiService = new OpenAiService(openAiKey, Duration.ofSeconds(240));
        this.temperature = temperature;
        this.maxTokens = maxTokens;
        this.model = model;
    }

    /**
     * Generates a chat response using the OpenAI API.
     * @param chatMessages List of chat messages to use as context for the response
     * @return Generated chat response
     */
    @Override
    public String generate(List<ChatMessage> chatMessages) {
        log.debug("Generating chat response for chat messages {}", chatMessages);
        ChatCompletionRequest chatCompletionRequest = ChatCompletionRequest.builder()
                .messages(chatMessages)
                .temperature(temperature)
                .maxTokens(maxTokens)
                .model(model.getCodename())
                .build();

        ChatCompletionChoice chatCompletionChoice = openAiService.createChatCompletion(chatCompletionRequest).getChoices().get(0);
        String chatCompletionContent = chatCompletionChoice.getMessage().getContent();

        log.debug("Chat completion response is {}", chatCompletionContent);

        return !chatCompletionContent.isEmpty() ?
                chatCompletionContent :
                String.format(NO_CONTENT_FOUND_OPEN_AI, chatMessages);
    }
}
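One detail worth noting in `generate` above is the empty-content guard at the end: rather than returning an empty string, the model substitutes a formatted not-found message. Below is a standalone sketch of just that guard, with the OpenAI call stubbed out; the message template is an assumption standing in for the real `NO_CONTENT_FOUND_OPEN_AI` constant from the package's `errors.Constants` class:

```java
import java.util.List;

public class FallbackSketch {
    // Assumed template; the real constant lives in the package's errors.Constants class
    static final String NO_CONTENT_FOUND_OPEN_AI = "No content found in OpenAI response for messages: %s";

    // Mirrors the ternary at the end of OpenAiChatGenerationModel#generate
    static String resolveContent(String completionContent, List<String> chatMessages) {
        return !completionContent.isEmpty()
                ? completionContent
                : String.format(NO_CONTENT_FOUND_OPEN_AI, chatMessages);
    }

    public static void main(String[] args) {
        System.out.println(resolveContent("a useful answer", List.of("what is an LLM?")));
        System.out.println(resolveContent("", List.of("what is an LLM?")));
    }
}
```

The second call exercises the fallback path that the `testGenerateWhenContentNotFound` test below asserts against.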

This file was deleted.

@@ -5,11 +5,11 @@
 import com.theokanning.openai.completion.chat.ChatCompletionResult;
 import com.theokanning.openai.completion.chat.ChatMessage;
 import com.theokanning.openai.service.OpenAiService;
+import org.junit.jupiter.api.BeforeEach;
 import org.junit.jupiter.api.Test;
 import org.junit.jupiter.api.extension.ExtendWith;
-import org.mockito.InjectMocks;
 import org.mockito.Mock;
-import org.mockito.Spy;
 import org.mockito.junit.jupiter.MockitoExtension;
 
 import java.util.List;
@@ -20,18 +20,22 @@
import static org.mockito.Mockito.when;

 @ExtendWith(MockitoExtension.class)
-public class OpenAiTextGenerationServiceTest {
-
+public class OpenAiChatGenerationModelTest {
     @Mock
     private OpenAiService openAiService;
-    @InjectMocks
-    private OpenAiTextGenerationService openAiTextGenerationService;
+    private OpenAiChatGenerationModel openAiChatGenerationModel;
 
     private final List<ChatMessage> chatMessages = List.of(new ChatMessage("user", "How can I rank up in Rocket League?"));
     private final Double temperature = 0.1;
     private final int maxTokens = 3000;
     private final OpenAiModels model = OpenAiModels.GPT_4;
 
+    @BeforeEach
+    public void setUp() {
+        openAiChatGenerationModel = new OpenAiChatGenerationModel("openAiKey", model, temperature, maxTokens);
+        openAiChatGenerationModel.openAiService = openAiService;
+    }
+
     @Test
     public void testGenerate() {
         String chatResult = "test content";
@@ -41,7 +45,7 @@ public void testGenerate() {
         result.setChoices(List.of(chatCompletionChoice));
         when(openAiService.createChatCompletion(getChatCompletion())).thenReturn(result);
 
-        String generatedText = openAiTextGenerationService.generate(chatMessages, temperature, maxTokens, model);
+        String generatedText = openAiChatGenerationModel.generate(chatMessages);
 
         assertEquals(chatResult, generatedText);
 
@@ -56,7 +60,7 @@ public void testGenerateWhenContentNotFound() {
         emptyResult.setChoices(List.of(chatCompletionChoice));
         when(openAiService.createChatCompletion(getChatCompletion())).thenReturn(emptyResult);
 
-        String generatedText = openAiTextGenerationService.generate(chatMessages, temperature, maxTokens, model);
+        String generatedText = openAiChatGenerationModel.generate(chatMessages);
 
         assertEquals(generatedText, String.format(NO_CONTENT_FOUND_OPEN_AI, chatMessages));
 