Skip to content

Commit

Permalink
Add scalar parameters
Browse files Browse the repository at this point in the history
  • Loading branch information
penguineer committed Oct 30, 2024
1 parent 39432d3 commit cfcb816
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 0 deletions.
23 changes: 23 additions & 0 deletions src/main/java/com/penguineering/hareairis/ai/AIChatService.java
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,12 @@
import com.penguineering.hareairis.model.ChatResponse;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.azure.openai.AzureOpenAiChatOptions;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.stereotype.Service;

import java.util.Objects;

/**
* Service to handle chat requests.
*
Expand All @@ -31,7 +34,10 @@ public AIChatService(ChatClient.Builder chatClientBuilder) {
*/
public ChatResponse handleChatRequest(ChatRequest chatRequest) {
try {
AzureOpenAiChatOptions options = renderAzureOpenAiChatOptions(chatRequest);

ChatClient chatClient = chatClientBuilder
.defaultOptions(options)
.defaultSystem(chatRequest.getSystemMessage())
.build();
var chatResponse = chatClient
Expand Down Expand Up @@ -59,4 +65,21 @@ public ChatResponse handleChatRequest(ChatRequest chatRequest) {
throw new ChatException(response.getStatusCode(), e.getMessage());
}
}

private static AzureOpenAiChatOptions renderAzureOpenAiChatOptions(ChatRequest chatRequest) {
AzureOpenAiChatOptions options = new AzureOpenAiChatOptions();

if (Objects.nonNull(chatRequest.getMaxTokens()))
options.setMaxTokens(chatRequest.getMaxTokens());
if (Objects.nonNull(chatRequest.getTemperature()))
options.setTemperature(chatRequest.getTemperature());
if (Objects.nonNull(chatRequest.getTopP()))
options.setTopP(chatRequest.getTopP());
if (Objects.nonNull(chatRequest.getPresencePenalty()))
options.setPresencePenalty(chatRequest.getPresencePenalty());
if (Objects.nonNull(chatRequest.getFrequencyPenalty()))
options.setFrequencyPenalty(chatRequest.getFrequencyPenalty());

return options;
}
}
47 changes: 47 additions & 0 deletions src/main/java/com/penguineering/hareairis/model/ChatRequest.java
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
import lombok.Getter;
import lombok.NoArgsConstructor;

import java.util.Optional;

@NoArgsConstructor
@AllArgsConstructor
@Getter
Expand All @@ -16,4 +18,49 @@ public class ChatRequest {

@JsonProperty("prompt")
private String prompt = "";

/**
* The maximum number of tokens to generate.
*/
@JsonProperty(value = "max-tokens")
private Integer maxTokens = null;

/**
* The sampling temperature to use that controls the apparent creativity of generated
* completions. Higher values will make output more random while lower values will
* make results more focused and deterministic. It is not recommended to modify
* temperature and top_p for the same completions request as the interaction of these
* two settings is difficult to predict.
*/
@JsonProperty(value = "temperature")
private Double temperature = null;

/**
* An alternative to sampling with temperature called nucleus sampling. This value
* causes the model to consider the results of tokens with the provided probability
* mass. As an example, a value of 0.15 will cause only the tokens comprising the top
* 15% of probability mass to be considered. It is not recommended to modify
* temperature and top_p for the same completions request as the interaction of these
* two settings is difficult to predict.
*/
@JsonProperty(value = "top-p")
private Double topP = null;

/**
* A value that influences the probability of generated tokens appearing based on
* their existing presence in generated text. Positive values will make tokens less
* likely to appear when they already exist and increase the model's likelihood to
* output new topics.
*/
@JsonProperty(value = "presence-penalty")
private Double presencePenalty = null;

/**
* A value that influences the probability of generated tokens appearing based on
* their cumulative frequency in generated text. Positive values will make tokens less
* likely to appear as their frequency increases and decrease the likelihood of the
* model repeating the same statements verbatim.
*/
@JsonProperty(value = "frequency-penalty")
private Double frequencyPenalty = null;
}

0 comments on commit cfcb816

Please sign in to comment.