-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #28 from theam/add_text_generation_module
add text generation module
- Loading branch information
Showing
10 changed files
with
265 additions
and
0 deletions.
There are no files selected for viewing
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
# Text generation module | ||
|
||
The text generation module provides an abstraction to use Large Language Models (LLMs) that, given a prompt, complete it with the most useful text possible. LLMs can be used for a very wide range of use cases, including summarizing text, drafting an essay, or getting answers for general knowledge questions. | ||
|
||
eLLMental provides the following implementations: | ||
|
||
## `OpenAiChatGenerationModel` | ||
|
||
The `OpenAiChatGenerationModel` is pre-configured to simulate a chat conversation. In this case, the model accepts a list of `ChatMessage` objects, where each message has a `role` and a `content` field. Then, the model will generate a response to the last message in the list. If you have used [ChatGPT](https://openai.com/chatgpt) before, the behavior will be very similar.
|
||
```java | ||
ChatMessage chatMessage = new ChatMessage("user", "what is an LLM?"); | ||
|
||
// Using the simplified constructor, the library will set sensible defaults for the temperature, maxTokens and model. | ||
OpenAiChatGenerationModel model = new OpenAiChatGenerationModel("<api key>"); | ||
String response = model.generate(List.of(chatMessage));
``` | ||
|
||
`OpenAiChatGenerationModel` is pre-configured to use the GPT-3.5 model by default, but if you have API access to more advanced models, you can specify them, as well as the temperature and maxTokens, using the alternate constructors:
|
||
```java | ||
// There's a convenience constructor that allows setting the model exclusively | ||
OpenAiChatGenerationModel modelWithGPT4 = new OpenAiChatGenerationModel("<api key>", OpenAiModels.GPT_4); | ||
OpenAiChatGenerationModel modelWithGPT316K = new OpenAiChatGenerationModel("<api key>", OpenAiModels.GPT_3_5_CONTEXT_16K); | ||
|
||
// Or you can use the full constructor to set the model, temperature and maxTokens | ||
OpenAiChatGenerationModel customModel = new OpenAiChatGenerationModel("<api key>", OpenAiModels.GPT_3_5_CONTEXT_16K, 0.5, 100); | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
plugins {
    id 'java'
    id "io.freefair.lombok" version "8.3"
}

// NOTE(review): group is 'ellemental' while the Java packages use 'ellmental' —
// presumably intentional branding, but worth confirming.
group = 'com.theagilemonkeys.ellemental'
version = '1.0-SNAPSHOT'

repositories {
    mavenCentral()
}

dependencies {
    implementation 'ch.qos.logback:logback-core:1.4.11'
    implementation 'org.slf4j:slf4j-api:2.0.9'
    implementation 'com.theokanning.openai-gpt3-java:service:0.16.0'

    testImplementation 'ch.qos.logback:logback-classic:1.4.11'
    testImplementation platform('org.junit:junit-bom:5.9.1')
    testImplementation 'org.junit.jupiter:junit-jupiter'
    // Pin mockito-core to the same version as mockito-junit-jupiter. The previous
    // dynamic range '3.+' resolved a Mockito 3.x core next to the 5.5.0 JUnit
    // extension, which are not binary-compatible at runtime.
    testImplementation 'org.mockito:mockito-core:5.5.0'
    testImplementation 'org.mockito:mockito-junit-jupiter:5.5.0'
}

test {
    useJUnitPlatform()
}
7 changes: 7 additions & 0 deletions
7
...ion/src/main/java/com/theagilemonkeys/ellmental/textgeneration/TextGenerationService.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
package com.theagilemonkeys.ellmental.textgeneration;

import java.util.List;

/**
 * Base abstraction for text generation backends.
 *
 * <p>Implementations take a list of inputs of type {@code Input} (for example,
 * chat messages) and produce a single generated string from them.
 *
 * @param <Input> type of the individual elements the model consumes
 */
public abstract class TextGenerationService<Input> {
    /**
     * Generates text from the given inputs.
     *
     * @param input ordered list of inputs used as context for the generation
     * @return the generated text
     */
    public abstract String generate(List<Input> input);
}
84 changes: 84 additions & 0 deletions
84
...n/java/com/theagilemonkeys/ellmental/textgeneration/openai/OpenAiChatGenerationModel.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,84 @@ | ||
package com.theagilemonkeys.ellmental.textgeneration.openai;

import com.theagilemonkeys.ellmental.textgeneration.TextGenerationService;
import com.theagilemonkeys.ellmental.textgeneration.openai.errors.NoContentFoundException;
import com.theokanning.openai.completion.chat.ChatCompletionChoice;
import com.theokanning.openai.completion.chat.ChatCompletionRequest;
import com.theokanning.openai.completion.chat.ChatMessage;
import com.theokanning.openai.service.OpenAiService;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.time.Duration;
import java.util.List;

/**
 * {@link TextGenerationService} backed by the OpenAI chat completions API.
 *
 * <p>Given a list of {@link ChatMessage} objects, it asks the configured OpenAI
 * model for a completion and returns the content of the first choice.
 */
public class OpenAiChatGenerationModel extends TextGenerationService<ChatMessage> {
    private static final Double DEFAULT_TEMPERATURE = 0.7;
    private static final int DEFAULT_MAX_TOKENS = 3000;
    // Generous timeout: chat completions for long prompts can take minutes.
    private static final Duration CLIENT_TIMEOUT = Duration.ofSeconds(240);
    private static final Logger log = LoggerFactory.getLogger(OpenAiChatGenerationModel.class);
    private final Double temperature;
    private final int maxTokens;
    private final OpenAiModels model;
    // The OpenAI client is package-private to allow injecting a mock in tests
    OpenAiService openAiService;

    /**
     * Constructor for OpenAiChatGenerationModel that uses default values for temperature (0.7),
     * maxTokens (3000) and model (GPT-3.5).
     *
     * @param openAiKey OpenAI API key
     */
    public OpenAiChatGenerationModel(String openAiKey) {
        this(openAiKey, OpenAiModels.GPT_3_5, DEFAULT_TEMPERATURE, DEFAULT_MAX_TOKENS);
    }

    /**
     * Constructor for OpenAiChatGenerationModel that explicitly sets the model, but uses the
     * default values for temperature (0.7) and maxTokens (3000).
     *
     * @param openAiKey OpenAI API key
     * @param model     Model to use for the chat generation
     */
    public OpenAiChatGenerationModel(String openAiKey, OpenAiModels model) {
        this(openAiKey, model, DEFAULT_TEMPERATURE, DEFAULT_MAX_TOKENS);
    }

    /**
     * Constructor for OpenAiChatGenerationModel that explicitly sets the model, temperature
     * and maxTokens.
     *
     * @param openAiKey   OpenAI API key
     * @param model       Model to use for the chat generation
     * @param temperature Temperature to use for the chat generation
     * @param maxTokens   Maximum number of tokens to use for the chat generation
     */
    public OpenAiChatGenerationModel(String openAiKey, OpenAiModels model, Double temperature, int maxTokens) {
        this.openAiService = new OpenAiService(openAiKey, CLIENT_TIMEOUT);
        this.temperature = temperature;
        this.maxTokens = maxTokens;
        this.model = model;
    }

    /**
     * Generates a chat response using the OpenAI API.
     *
     * @param chatMessages List of chat messages to use as context for the response
     * @return Generated chat response
     * @throws NoContentFoundException if the API returns no choices or an empty/absent content
     */
    @Override
    public String generate(List<ChatMessage> chatMessages) {
        log.debug("Generating chat response for chat messages {}", chatMessages);
        ChatCompletionRequest chatCompletionRequest = ChatCompletionRequest.builder()
                .messages(chatMessages)
                .temperature(temperature)
                .maxTokens(maxTokens)
                .model(model.getCodename())
                .build();

        List<ChatCompletionChoice> choices = openAiService.createChatCompletion(chatCompletionRequest).getChoices();
        // Guard against a missing/empty choices list; previously this threw
        // IndexOutOfBoundsException instead of the domain exception.
        if (choices == null || choices.isEmpty()) {
            throw new NoContentFoundException(chatMessages);
        }

        String chatCompletionContent = choices.get(0).getMessage().getContent();

        log.debug("Chat completion response is {}", chatCompletionContent);

        // Content may be null in some API responses; a plain isEmpty() call
        // would then throw NullPointerException instead of the domain exception.
        if (chatCompletionContent == null || chatCompletionContent.isEmpty()) {
            throw new NoContentFoundException(chatMessages);
        }

        return chatCompletionContent;
    }
}
17 changes: 17 additions & 0 deletions
17
...ation/src/main/java/com/theagilemonkeys/ellmental/textgeneration/openai/OpenAiModels.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
package com.theagilemonkeys.ellmental.textgeneration.openai; | ||
|
||
public enum OpenAiModels { | ||
GPT_3_5_CONTEXT_16K("gpt-3.5-turbo-16k"), | ||
GPT_3_5("gpt-3.5-turbo"), | ||
GPT_4("gpt-4"); | ||
|
||
private final String codename; | ||
|
||
OpenAiModels(String codename) { | ||
this.codename = codename; | ||
} | ||
|
||
public String getCodename() { | ||
return this.codename; | ||
} | ||
} |
13 changes: 13 additions & 0 deletions
13
...a/com/theagilemonkeys/ellmental/textgeneration/openai/errors/NoContentFoundException.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
package com.theagilemonkeys.ellmental.textgeneration.openai.errors;

import com.theokanning.openai.completion.chat.ChatMessage;

import java.util.List;

/**
 * Thrown when an OpenAI chat completion response carries no usable content
 * for the given conversation messages.
 */
public class NoContentFoundException extends RuntimeException {
    // Template for the exception message; %s is replaced with the message list.
    private static final String MESSAGE_TEMPLATE = "No content found in response for messages %s";

    /**
     * Creates the exception, embedding the offending messages in the message text.
     *
     * @param messages chat messages for which no content was generated
     */
    public NoContentFoundException(List<ChatMessage> messages) {
        super(String.format(MESSAGE_TEMPLATE, messages));
    }
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
<configuration>
    <!-- Console appender: writes log events to stdout with a timestamp,
         thread name, level, truncated logger name, and the message. -->
    <appender name="STDOUT" class="ch.qos.logback.core.ConsoleAppender">
        <encoder>
            <pattern>%d{HH:mm:ss.SSS} [%thread] %-5level %logger{36} - %msg%n</pattern>
        </encoder>
    </appender>

    <!-- Root logger at DEBUG so the module's log.debug(...) calls are visible;
         NOTE(review): appears to be a test-scoped config — confirm it is not
         shipped in the production artifact, where DEBUG would be noisy. -->
    <root level="debug">
        <appender-ref ref="STDOUT" />
    </root>
</configuration>
76 changes: 76 additions & 0 deletions
76
...va/com/theagilemonkeys/ellmental/textgeneration/openai/OpenAiChatGenerationModelTest.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,76 @@ | ||
package com.theagilemonkeys.ellmental.textgeneration.openai;

import com.theagilemonkeys.ellmental.textgeneration.openai.errors.NoContentFoundException;
import com.theokanning.openai.completion.chat.ChatCompletionChoice;
import com.theokanning.openai.completion.chat.ChatCompletionRequest;
import com.theokanning.openai.completion.chat.ChatCompletionResult;
import com.theokanning.openai.completion.chat.ChatMessage;
import com.theokanning.openai.service.OpenAiService;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;

import java.util.List;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoMoreInteractions;
import static org.mockito.Mockito.when;

/**
 * Unit tests for {@link OpenAiChatGenerationModel} with a mocked OpenAI client.
 */
@ExtendWith(MockitoExtension.class)
public class OpenAiChatGenerationModelTest {
    @Mock
    private OpenAiService openAiService;
    private OpenAiChatGenerationModel openAiChatGenerationModel;

    private final List<ChatMessage> chatMessages = List.of(new ChatMessage("user", "How can I rank up in Rocket League?"));
    private final Double temperature = 0.1;
    private final int maxTokens = 3000;
    private final OpenAiModels model = OpenAiModels.GPT_4;

    @BeforeEach
    public void setUp() {
        openAiChatGenerationModel = new OpenAiChatGenerationModel("openAiKey", model, temperature, maxTokens);
        // Replace the real client with the mock via the package-private field.
        openAiChatGenerationModel.openAiService = openAiService;
    }

    @Test
    public void testGenerate() {
        String chatResult = "test content";
        ChatCompletionChoice chatCompletionChoice = new ChatCompletionChoice();
        chatCompletionChoice.setMessage(new ChatMessage("user", chatResult));
        ChatCompletionResult result = new ChatCompletionResult();
        result.setChoices(List.of(chatCompletionChoice));
        when(openAiService.createChatCompletion(getChatCompletion())).thenReturn(result);

        String generatedText = openAiChatGenerationModel.generate(chatMessages);

        assertEquals(chatResult, generatedText);

        // Stubbed invocations count as interactions in Mockito, so the call must
        // be verified explicitly or verifyNoMoreInteractions below would fail.
        verify(openAiService).createChatCompletion(getChatCompletion());
        verifyNoMoreInteractions(openAiService);
    }

    @Test
    public void testGenerateWhenContentNotFound() {
        ChatCompletionChoice chatCompletionChoice = new ChatCompletionChoice();
        chatCompletionChoice.setMessage(new ChatMessage("user", ""));
        ChatCompletionResult emptyResult = new ChatCompletionResult();
        emptyResult.setChoices(List.of(chatCompletionChoice));
        when(openAiService.createChatCompletion(getChatCompletion())).thenReturn(emptyResult);

        assertThrows(NoContentFoundException.class, () -> openAiChatGenerationModel.generate(chatMessages));

        // Same reasoning as above: verify the stubbed call before asserting
        // there were no further interactions with the mock.
        verify(openAiService).createChatCompletion(getChatCompletion());
        verifyNoMoreInteractions(openAiService);
    }

    // Builds the request the model under test is expected to send to OpenAI.
    private ChatCompletionRequest getChatCompletion() {
        return ChatCompletionRequest.builder()
                .messages(chatMessages)
                .temperature(temperature)
                .maxTokens(maxTokens)
                .model(model.getCodename())
                .build();
    }
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters