Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enhance the chat request configuration #4

Merged
merged 4 commits into from
Oct 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 49 additions & 15 deletions src/main/java/com/penguineering/hareairis/ai/AIChatService.java
Original file line number Diff line number Diff line change
@@ -1,12 +1,17 @@
package com.penguineering.hareairis.ai;

import com.azure.core.exception.HttpResponseException;
import com.penguineering.hareairis.model.ChatException;
import com.penguineering.hareairis.model.ChatRequest;
import com.penguineering.hareairis.model.ChatResponse;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.azure.openai.AzureOpenAiChatOptions;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.stereotype.Service;

import java.util.Objects;

/**
* Service to handle chat requests.
*
Expand All @@ -28,24 +33,53 @@ public AIChatService(ChatClient.Builder chatClientBuilder) {
* @return The chat response.
*/
public ChatResponse handleChatRequest(ChatRequest chatRequest) {
    // Flow: translate the request into Azure OpenAI options, run a single prompt,
    // then map the model output and token usage into a ChatResponse.
    // Failures surface as ChatException: an IllegalArgumentException becomes
    // CODE_BAD_REQUEST, an HttpResponseException keeps the upstream HTTP status code.
    try {
        AzureOpenAiChatOptions options = renderAzureOpenAiChatOptions(chatRequest);

        // Configure this prompt only. Calling defaultOptions()/defaultSystem() on
        // chatClientBuilder would rewrite the defaults of the builder shared by all
        // requests handled by this service.
        ChatClient.ChatClientRequestSpec promptSpec = chatClientBuilder
                .build()
                .prompt()
                .options(options);

        // The system message is optional (ChatRequest defaults it to ""); skip it
        // when absent or blank rather than passing an empty system text along.
        String systemMessage = chatRequest.getSystemMessage();
        if (systemMessage != null && !systemMessage.isBlank()) {
            promptSpec = promptSpec.system(systemMessage);
        }

        var chatResponse = promptSpec
                .user(chatRequest.getPrompt())
                .call()
                .chatResponse();

        String response = chatResponse.getResult().getOutput().getContent();
        logger.info("Received response from OpenAI: {}", response);

        Long promptTokens = chatResponse.getMetadata().getUsage().getPromptTokens();
        Long generationTokens = chatResponse.getMetadata().getUsage().getGenerationTokens();

        return ChatResponse.builder()
                .response(response)
                .inputTokens(promptTokens.intValue())
                .outputTokens(generationTokens.intValue())
                .build();
    } catch (IllegalArgumentException e) {
        throw new ChatException(ChatException.Code.CODE_BAD_REQUEST, e.getMessage());
    } catch (HttpResponseException e) {
        var response = e.getResponse();
        throw new ChatException(response.getStatusCode(), e.getMessage());
    }
}

logger.info("Received response from OpenAI: {}", response);
/**
 * Builds the Azure OpenAI chat options for one chat request.
 * <p>
 * Only values that are non-null on the request are copied: max tokens, temperature,
 * top-p, presence penalty and frequency penalty. Everything else is left as created
 * by the {@code AzureOpenAiChatOptions} constructor.
 *
 * @param chatRequest the request whose optional tuning values are read
 * @return a new options object carrying the values present on the request
 */
private static AzureOpenAiChatOptions renderAzureOpenAiChatOptions(ChatRequest chatRequest) {
AzureOpenAiChatOptions options = new AzureOpenAiChatOptions();

// NOTE(review): chatResponse is not declared in this method; the next two lines look like
// removed diff lines from handleChatRequest that the page rendering kept in place - confirm.
Long promptTokens = chatResponse.getMetadata().getUsage().getPromptTokens();
Long generationTokens = chatResponse.getMetadata().getUsage().getGenerationTokens();
// Copy each tuning value only when the request supplies one.
if (Objects.nonNull(chatRequest.getMaxTokens()))
options.setMaxTokens(chatRequest.getMaxTokens());
if (Objects.nonNull(chatRequest.getTemperature()))
options.setTemperature(chatRequest.getTemperature());
if (Objects.nonNull(chatRequest.getTopP()))
options.setTopP(chatRequest.getTopP());
if (Objects.nonNull(chatRequest.getPresencePenalty()))
options.setPresencePenalty(chatRequest.getPresencePenalty());
if (Objects.nonNull(chatRequest.getFrequencyPenalty()))
options.setFrequencyPenalty(chatRequest.getFrequencyPenalty());

// NOTE(review): this ChatResponse return does not match the declared return type and also
// appears to be a leftover removed diff group; "return options" below is the matching one - confirm.
return ChatResponse.builder()
.response(response)
.inputTokens(promptTokens.intValue())
.outputTokens(generationTokens.intValue())
.build();
return options;
}
}
34 changes: 6 additions & 28 deletions src/main/java/com/penguineering/hareairis/model/ChatError.java
Original file line number Diff line number Diff line change
Expand Up @@ -6,41 +6,19 @@
import lombok.Getter;
import lombok.NoArgsConstructor;

// NOTE(review): this listing contains two class headers for ChatError (one extending
// RuntimeException, one plain) and an unclosed is5xxServerError(); it looks like removed and
// added diff lines shown together - confirm against the merged file before editing.
/**
 * Error payload with a numeric code and a message, bound to the JSON properties
 * "code" and "message". Unknown JSON properties are ignored.
 */
@AllArgsConstructor
@NoArgsConstructor
@Getter
@JsonIgnoreProperties(ignoreUnknown = true)
public class ChatError extends RuntimeException {
// Named status codes, each wrapping its numeric value.
@Getter
@AllArgsConstructor
public enum Code {
CODE_BAD_REQUEST(400),
CODE_TOO_MANY_REQUESTS(429),
CODE_INTERNAL_SERVER_ERROR(500),
CODE_GATEWAY_TIMEOUT(504);

private final int code;

}

public class ChatError {
// Numeric error code.
@JsonProperty("code")
private int code;

// Human-readable error description.
@JsonProperty("message")
private String message;

// Message-only variant: uses the CODE_INTERNAL_SERVER_ERROR value as the code.
public ChatError(String message) {
super(message);
this.code = Code.CODE_INTERNAL_SERVER_ERROR.getCode();
this.message = message;
}

// Takes the numeric value from the given Code constant.
public ChatError(Code code, String message) {
super(message);
this.code = code.getCode();
this.message = message;
}

// True when the code lies in the range 500..599.
public boolean is5xxServerError() {
return code >= 500 && code < 600;
// Copies code and message from the given ChatException.
public ChatError(ChatException ex) {
this.code = ex.getCode();
this.message = ex.getMessage();
}
}
}
41 changes: 41 additions & 0 deletions src/main/java/com/penguineering/hareairis/model/ChatException.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
package com.penguineering.hareairis.model;

import com.fasterxml.jackson.annotation.*;
import lombok.AllArgsConstructor;
import lombok.Getter;
import lombok.NoArgsConstructor;

public class ChatException extends RuntimeException {
@Getter
@AllArgsConstructor
public enum Code {
CODE_BAD_REQUEST(400),
CODE_TOO_MANY_REQUESTS(429),
CODE_INTERNAL_SERVER_ERROR(500),
CODE_GATEWAY_TIMEOUT(504);

private final int code;
}

@Getter
private final int code;

public ChatException(String message) {
super(message);
this.code = Code.CODE_INTERNAL_SERVER_ERROR.getCode();
}

public ChatException(Code code, String message) {
super(message);
this.code = code.getCode();
}

public ChatException(int code, String message) {
super(message);
this.code = code;
}

public boolean is5xxServerError() {
return code >= 500 && code < 600;
}
}
55 changes: 52 additions & 3 deletions src/main/java/com/penguineering/hareairis/model/ChatRequest.java
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,64 @@
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import com.fasterxml.jackson.annotation.JsonProperty;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.Getter;
import lombok.NoArgsConstructor;

import java.util.Optional;

/**
 * Incoming chat request as read from JSON.
 * <p>
 * Carries the prompt text, a system message and optional tuning values. The tuning
 * values default to {@code null}; unknown JSON properties are ignored.
 */
@NoArgsConstructor
@AllArgsConstructor
@Getter
@JsonIgnoreProperties(ignoreUnknown = true)
public class ChatRequest {
// NOTE(review): "message" is listed next to "prompt"/"system-message"; it may be a removed
// diff line that the page rendering kept - confirm whether the field still exists.
@JsonProperty("message")
private String message;
// System message text, JSON property "system-message"; empty string by default.
@JsonProperty("system-message")
private String systemMessage = "";

// Prompt text, JSON property "prompt"; empty string by default.
@JsonProperty("prompt")
private String prompt = "";

/**
* The maximum number of tokens to generate.
*/
@JsonProperty(value = "max-tokens")
private Integer maxTokens = null;

/**
* The sampling temperature to use that controls the apparent creativity of generated
* completions. Higher values will make output more random while lower values will
* make results more focused and deterministic. It is not recommended to modify
* temperature and top_p for the same completions request as the interaction of these
* two settings is difficult to predict.
*/
@JsonProperty(value = "temperature")
private Double temperature = null;

/**
* An alternative to sampling with temperature called nucleus sampling. This value
* causes the model to consider the results of tokens with the provided probability
* mass. As an example, a value of 0.15 will cause only the tokens comprising the top
* 15% of probability mass to be considered. It is not recommended to modify
* temperature and top_p for the same completions request as the interaction of these
* two settings is difficult to predict.
*/
@JsonProperty(value = "top-p")
private Double topP = null;

/**
* A value that influences the probability of generated tokens appearing based on
* their existing presence in generated text. Positive values will make tokens less
* likely to appear when they already exist and increase the model's likelihood to
* output new topics.
*/
@JsonProperty(value = "presence-penalty")
private Double presencePenalty = null;

/**
* A value that influences the probability of generated tokens appearing based on
* their cumulative frequency in generated text. Positive values will make tokens less
* likely to appear as their frequency increases and decrease the likelihood of the
* model repeating the same statements verbatim.
*/
@JsonProperty(value = "frequency-penalty")
private Double frequencyPenalty = null;
}
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
package com.penguineering.hareairis.rmq;

import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.penguineering.hareairis.ai.AIChatService;
import com.penguineering.hareairis.model.ChatError;
import com.penguineering.hareairis.model.ChatException;
import com.penguineering.hareairis.model.ChatRequest;
import com.penguineering.hareairis.model.ChatResponse;
import com.rabbitmq.client.Channel;
Expand Down Expand Up @@ -76,7 +78,7 @@ public void onMessage(Message message, Channel channel) {
String replyTo = Optional
.ofNullable(message.getMessageProperties())
.map(MessageProperties::getReplyTo)
.orElseThrow(() -> new ChatError(ChatError.Code.CODE_BAD_REQUEST, "Reply_to property is missing"));
.orElseThrow(() -> new ChatException(ChatException.Code.CODE_BAD_REQUEST, "Reply_to property is missing"));
logger.info("Reply-to header: {}", replyTo);


Expand All @@ -95,6 +97,7 @@ public void onMessage(Message message, Channel channel) {
// Acknowledge the message
channel.basicAck(deliveryTag, false);
} catch (Exception e) {
logger.info("Error on chat request", e);
Optional<String> json = serializeChatError(e);
errorTo.ifPresentOrElse(
to -> json.ifPresent(
Expand All @@ -108,48 +111,56 @@ public void onMessage(Message message, Channel channel) {

/**
 * Settles the delivery after a failed chat request, exactly once.
 * <p>
 * A {@link ChatException} with a 5xx code is negatively acknowledged with requeue,
 * so the request is delivered again. Every other failure is acknowledged, which
 * removes the message from the queue.
 *
 * @param e           the exception raised while handling the message
 * @param channel     the channel the message was delivered on
 * @param deliveryTag the delivery tag of the message to settle
 */
private void doExceptionBasedAck(Exception e, Channel channel, long deliveryTag) {
    try {
        if (e instanceof ChatException chatException && chatException.is5xxServerError()) {
            // Do not acknowledge the message: reject it and ask for redelivery.
            channel.basicNack(deliveryTag, false, true);
        } else {
            // Acknowledge the message. This must stay in the else branch: settling a
            // delivery tag that was already nack'ed would settle it a second time.
            channel.basicAck(deliveryTag, false);
        }
    } catch (IOException ex) {
        logger.error("Failed send message (n)ack!", ex);
    }
}

// NOTE(review): both a ChatError and a ChatException variant of the signature and of the
// throw statement are listed; this looks like removed and added diff lines shown together -
// confirm which one is current.
/**
 * Parses the body of the incoming message into a ChatRequest via the object mapper.
 * Any exception raised while reading is rethrown with CODE_BAD_REQUEST and a message
 * that includes the underlying exception message.
 *
 * @param message the incoming message whose body is read
 * @return the deserialized chat request
 */
private ChatRequest deserializeChatRequest(Message message) throws ChatError {
private ChatRequest deserializeChatRequest(Message message) throws ChatException {
try {
return objectMapper.readValue(message.getBody(), ChatRequest.class);
} catch (Exception e) {
// Only the message text of the original exception is carried over.
throw new ChatError(ChatError.Code.CODE_BAD_REQUEST,
throw new ChatException(ChatException.Code.CODE_BAD_REQUEST,
"Failed to deserialize chat request: " + e.getMessage());
}
}

// NOTE(review): both a ChatError and a ChatException variant of the signature and of the
// throw statement are listed; this looks like removed and added diff lines shown together -
// confirm which one is current.
/**
 * Writes the chat response as a JSON string via the object mapper.
 * Any exception raised while writing is logged and rethrown with
 * CODE_INTERNAL_SERVER_ERROR and a message that includes the underlying exception message.
 *
 * @param response the chat response to serialize
 * @return the JSON text of the response
 */
private String serializeChatResponse(ChatResponse response) throws ChatError {
private String serializeChatResponse(ChatResponse response) throws ChatException {
try {
return objectMapper.writeValueAsString(response);
} catch (Exception e) {
// The full exception goes to the log; the rethrown error carries only its message text.
logger.error("Failed to serialize chat response", e);
throw new ChatError(ChatError.Code.CODE_INTERNAL_SERVER_ERROR,
throw new ChatException(ChatException.Code.CODE_INTERNAL_SERVER_ERROR,
"Failed to serialize chat response: " + e.getMessage());
}
}

/**
 * Renders an exception as the JSON text of a {@link ChatError}.
 * <p>
 * A {@link ChatException} is used as is. Any other exception is first wrapped in a
 * {@code ChatException} built from its message alone, which gives it the
 * internal-server-error code.
 *
 * @param e the exception to report
 * @return the JSON text, or an empty Optional if serialization failed (the failure is logged)
 */
private Optional<String> serializeChatError(Exception e) {
    ChatException chatEx = e instanceof ChatException known
            ? known
            : new ChatException(e.getMessage());

    try {
        // ofNullable keeps the "no JSON, no reply" contract even if the mapper yields null.
        return Optional.ofNullable(objectMapper.writeValueAsString(new ChatError(chatEx)));
    } catch (Exception ex) {
        // Best effort: a failure to describe the error must not raise a second error.
        logger.error("Failed to serialize error", ex);
        return Optional.empty();
    }
}
}
Loading