Skip to content

Commit

Permalink
Merge pull request #28 from theam/add_text_generation_module
Browse files Browse the repository at this point in the history
add text generation module
  • Loading branch information
juanjoman authored Sep 6, 2023
2 parents 7b46e24 + 3f7190a commit 5f3e171
Show file tree
Hide file tree
Showing 10 changed files with 265 additions and 0 deletions.
1 change: 1 addition & 0 deletions .idea/gradle.xml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

28 changes: 28 additions & 0 deletions docs-site/docs/03_components/03_text_generation.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
# Text generation module

The text generation module provides an abstraction to use Large Language Models (LLMs) that, given a prompt, complete it with the most useful text possible. LLMs can be used for a very wide range of use cases, including summarizing text, drafting an essay, or getting answers for general knowledge questions.

eLLMental provides the following implementations:

## `OpenAiChatGenerationModel`

The `OpenAiChatGenerationModel` is pre-configured to simulate a chat conversation. In this case, the model accepts a list of `ChatMessage` objects, where each message has a `user` and a `text` field. Then, the model will generate a response to the last message in the list. If you have used [ChatGPT](https://openai.com/chatgpt) before, the behavior will be very similar.

```java
ChatMessage chatMessage = new ChatMessage("user", "what is an LLM?");

// Using the simplified constructor, the library will set sensible defaults for the temperature, maxTokens and model.
OpenAiChatGenerationModel model = new OpenAiChatGenerationModel("<api key>");
String response = model.generate(List.of(chatMessage));
```

`OpenAiChatGenerationModel` is pre-configured to use the GPT-3.5 model by default, but if you have API access to more advanced models, you can specify them, as well as the temperature and maxTokens, using the alternate constructors:

```java
// There's a convenience constructor that allows setting the model exclusively
OpenAiChatGenerationModel modelWithGPT4 = new OpenAiChatGenerationModel("<api key>", OpenAiModels.GPT_4);
OpenAiChatGenerationModel modelWithGPT316K = new OpenAiChatGenerationModel("<api key>", OpenAiModels.GPT_3_5_CONTEXT_16K);

// Or you can use the full constructor to set the model, temperature and maxTokens
OpenAiChatGenerationModel customModel = new OpenAiChatGenerationModel("<api key>", OpenAiModels.GPT_3_5_CONTEXT_16K, 0.5, 100);
```
27 changes: 27 additions & 0 deletions modules/textgeneration/build.gradle
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
plugins {
    id 'java'
    id "io.freefair.lombok" version "8.3"
}

group = 'com.theagilemonkeys.ellemental'
version = '1.0-SNAPSHOT'

repositories {
    mavenCentral()
}

dependencies {
    implementation 'ch.qos.logback:logback-core:1.4.11'
    implementation 'org.slf4j:slf4j-api:2.0.9'
    implementation 'com.theokanning.openai-gpt3-java:service:0.16.0'

    testImplementation 'ch.qos.logback:logback-classic:1.4.11'
    testImplementation platform('org.junit:junit-bom:5.9.1')
    testImplementation 'org.junit.jupiter:junit-jupiter'
    // Keep mockito-core in the same version family as mockito-junit-jupiter.
    // The previous dynamic '3.+' range conflicted with the 5.5.0 transitive
    // requirement of mockito-junit-jupiter and relied on Gradle's conflict
    // resolution to silently upgrade it.
    testImplementation 'org.mockito:mockito-core:5.5.0'
    testImplementation 'org.mockito:mockito-junit-jupiter:5.5.0'
}

test {
    useJUnitPlatform()
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
package com.theagilemonkeys.ellmental.textgeneration;

import java.util.List;

/**
 * Base abstraction for services that generate text from a sequence of inputs.
 *
 * @param <Input> the type of each element in the generation context
 *                (e.g. a chat message for chat-oriented implementations)
 */
public abstract class TextGenerationService<Input> {
    /**
     * Generates a text completion from the given ordered list of inputs.
     *
     * @param input the generation context; implementations define its semantics
     * @return the generated text
     */
    public abstract String generate(List<Input> input);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
package com.theagilemonkeys.ellmental.textgeneration.openai;

import com.theagilemonkeys.ellmental.textgeneration.TextGenerationService;
import com.theagilemonkeys.ellmental.textgeneration.openai.errors.NoContentFoundException;
import com.theokanning.openai.completion.chat.ChatCompletionChoice;
import com.theokanning.openai.completion.chat.ChatCompletionRequest;
import com.theokanning.openai.completion.chat.ChatMessage;
import com.theokanning.openai.service.OpenAiService;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.time.Duration;
import java.util.List;

/**
 * {@link TextGenerationService} implementation that generates chat responses
 * through the OpenAI chat completions API.
 */
public class OpenAiChatGenerationModel extends TextGenerationService<ChatMessage> {
    private static final Double DEFAULT_TEMPERATURE = 0.7;
    private static final int DEFAULT_MAX_TOKENS = 3000;
    private static final Logger log = LoggerFactory.getLogger(OpenAiChatGenerationModel.class);
    private final Double temperature;
    private final int maxTokens;
    private final OpenAiModels model;
    // The OpenAI client is package-private to allow injecting a mock in tests
    OpenAiService openAiService;

    /**
     * Constructor for OpenAiChatGenerationModel that uses default values for temperature (0.7),
     * maxTokens (3000) and model (GPT-3.5).
     * @param openAiKey OpenAI API key
     */
    public OpenAiChatGenerationModel(String openAiKey) {
        this(openAiKey, OpenAiModels.GPT_3_5, DEFAULT_TEMPERATURE, DEFAULT_MAX_TOKENS);
    }

    /**
     * Constructor for OpenAiChatGenerationModel that explicitly sets the model, but uses the
     * default values for temperature (0.7) and maxTokens (3000).
     * @param openAiKey OpenAI API key
     * @param model Model to use for the chat generation
     */
    public OpenAiChatGenerationModel(String openAiKey, OpenAiModels model) {
        this(openAiKey, model, DEFAULT_TEMPERATURE, DEFAULT_MAX_TOKENS);
    }

    /**
     * Constructor for OpenAiChatGenerationModel that explicitly sets the model, temperature
     * and maxTokens.
     * @param openAiKey OpenAI API key
     * @param model Model to use for the chat generation
     * @param temperature Temperature to use for the chat generation
     * @param maxTokens Maximum number of tokens to use for the chat generation
     */
    public OpenAiChatGenerationModel(String openAiKey, OpenAiModels model, Double temperature, int maxTokens) {
        // 240s timeout: chat completions on large contexts can take well over the client default
        this.openAiService = new OpenAiService(openAiKey, Duration.ofSeconds(240));
        this.temperature = temperature;
        this.maxTokens = maxTokens;
        this.model = model;
    }

    /**
     * Generates a chat response using the OpenAI API.
     * @param chatMessages List of chat messages to use as context for the response
     * @return Generated chat response
     * @throws NoContentFoundException if the API returns no choices or an empty/absent content
     */
    @Override
    public String generate(List<ChatMessage> chatMessages) {
        log.debug("Generating chat response for chat messages {}", chatMessages);
        ChatCompletionRequest chatCompletionRequest = ChatCompletionRequest.builder()
                .messages(chatMessages)
                .temperature(temperature)
                .maxTokens(maxTokens)
                .model(model.getCodename())
                .build();

        List<ChatCompletionChoice> choices = openAiService.createChatCompletion(chatCompletionRequest).getChoices();
        // Guard against an empty choices list before indexing into it.
        if (choices == null || choices.isEmpty()) {
            throw new NoContentFoundException(chatMessages);
        }

        String chatCompletionContent = choices.get(0).getMessage().getContent();

        log.debug("Chat completion response is {}", chatCompletionContent);

        // getContent() may be null (e.g. function-call responses); the previous
        // bare isEmpty() check would have thrown an NPE instead of the intended
        // NoContentFoundException.
        if (chatCompletionContent == null || chatCompletionContent.isEmpty()) {
            throw new NoContentFoundException(chatMessages);
        }

        return chatCompletionContent;
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
package com.theagilemonkeys.ellmental.textgeneration.openai;

/**
 * OpenAI chat models supported by the text generation module, each paired with
 * the identifier string the OpenAI API expects.
 */
public enum OpenAiModels {
    GPT_3_5_CONTEXT_16K("gpt-3.5-turbo-16k"),
    GPT_3_5("gpt-3.5-turbo"),
    GPT_4("gpt-4");

    private final String codename;

    OpenAiModels(String apiIdentifier) {
        this.codename = apiIdentifier;
    }

    /**
     * Returns the model identifier string to send to the OpenAI API.
     *
     * @return the API-facing model codename
     */
    public String getCodename() {
        return codename;
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
package com.theagilemonkeys.ellmental.textgeneration.openai.errors;

import com.theokanning.openai.completion.chat.ChatMessage;

import java.util.List;

/**
 * Thrown when an OpenAI chat completion response carries no usable content
 * for the given conversation.
 */
public class NoContentFoundException extends RuntimeException {
    /**
     * Builds the exception with a message embedding the offending conversation.
     *
     * @param messages the chat messages that produced an empty response
     */
    public NoContentFoundException(List<ChatMessage> messages) {
        // Equivalent to String.format("No content found in response for messages %s", messages)
        super("No content found in response for messages " + messages);
    }
}
11 changes: 11 additions & 0 deletions modules/textgeneration/src/main/resources/logback.xml
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
<configuration>
    <!-- Console appender: writes log events to stdout. -->
    <appender name="STDOUT" class="ch.qos.logback.core.ConsoleAppender">
        <encoder>
            <!-- time [thread] LEVEL logger(abbreviated to 36 chars) - message -->
            <pattern>%d{HH:mm:ss.SSS} [%thread] %-5level %logger{36} - %msg%n</pattern>
        </encoder>
    </appender>

    <!-- NOTE(review): root level "debug" is very chatty for a library module;
         consider "info" for release builds — confirm with the team. -->
    <root level="debug">
        <appender-ref ref="STDOUT" />
    </root>
</configuration>
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
package com.theagilemonkeys.ellmental.textgeneration.openai;

import com.theagilemonkeys.ellmental.textgeneration.openai.errors.NoContentFoundException;
import com.theokanning.openai.completion.chat.ChatCompletionChoice;
import com.theokanning.openai.completion.chat.ChatCompletionRequest;
import com.theokanning.openai.completion.chat.ChatCompletionResult;
import com.theokanning.openai.completion.chat.ChatMessage;
import com.theokanning.openai.service.OpenAiService;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;

import java.util.List;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.mockito.Mockito.verifyNoMoreInteractions;
import static org.mockito.Mockito.when;

/**
 * Unit tests for {@link OpenAiChatGenerationModel}, using a mocked
 * {@link OpenAiService} injected through the package-private field.
 */
@ExtendWith(MockitoExtension.class)
public class OpenAiChatGenerationModelTest {
    @Mock
    private OpenAiService openAiService;
    private OpenAiChatGenerationModel openAiChatGenerationModel;

    private final List<ChatMessage> chatMessages = List.of(new ChatMessage("user", "How can I rank up in Rocket League?"));
    private final Double temperature = 0.1;
    private final int maxTokens = 3000;
    private final OpenAiModels model = OpenAiModels.GPT_4;

    @BeforeEach
    public void setUp() {
        openAiChatGenerationModel = new OpenAiChatGenerationModel("openAiKey", model, temperature, maxTokens);
        // Replace the real client built by the constructor with the mock.
        openAiChatGenerationModel.openAiService = openAiService;
    }

    @Test
    public void testGenerate() {
        String chatResult = "test content";
        ChatCompletionChoice chatCompletionChoice = new ChatCompletionChoice();
        chatCompletionChoice.setMessage(new ChatMessage("user", chatResult));
        ChatCompletionResult result = new ChatCompletionResult();
        result.setChoices(List.of(chatCompletionChoice));
        when(openAiService.createChatCompletion(getChatCompletion())).thenReturn(result);

        String generatedText = openAiChatGenerationModel.generate(chatMessages);

        assertEquals(chatResult, generatedText);

        // Explicitly verify the stubbed call: verifyNoMoreInteractions alone
        // fails on unverified stubbed invocations (Mockito counts them as
        // interactions unless ignoreStubs is used).
        verify(openAiService).createChatCompletion(getChatCompletion());
        verifyNoMoreInteractions(openAiService);
    }

    @Test
    public void testGenerateWhenContentNotFound() {
        ChatCompletionChoice chatCompletionChoice = new ChatCompletionChoice();
        chatCompletionChoice.setMessage(new ChatMessage("user", ""));
        ChatCompletionResult emptyResult = new ChatCompletionResult();
        emptyResult.setChoices(List.of(chatCompletionChoice));
        when(openAiService.createChatCompletion(getChatCompletion())).thenReturn(emptyResult);

        assertThrows(NoContentFoundException.class, () -> openAiChatGenerationModel.generate(chatMessages));

        verify(openAiService).createChatCompletion(getChatCompletion());
        verifyNoMoreInteractions(openAiService);
    }

    /** Builds the request the model under test is expected to send, for stubbing/verification. */
    private ChatCompletionRequest getChatCompletion() {
        return ChatCompletionRequest.builder()
                .messages(chatMessages)
                .temperature(temperature)
                .maxTokens(maxTokens)
                .model(model.getCodename())
                .build();
    }
}
1 change: 1 addition & 0 deletions settings.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ include(
"modules:embeddingsgeneration",
"modules:embeddingsstore",
"modules:embeddingsspace",
"modules:textgeneration",
"examples:simplejava",
"examples:webcrawler"
)

0 comments on commit 5f3e171

Please sign in to comment.