Skip to content

Commit

Permalink
Merge pull request #3 from penguineer/mvp-refinement
Browse files Browse the repository at this point in the history
Refine the API and architecture
  • Loading branch information
penguineer authored Oct 30, 2024
2 parents 8cb1aed + 2193ac4 commit e04704e
Show file tree
Hide file tree
Showing 6 changed files with 234 additions and 48 deletions.
51 changes: 51 additions & 0 deletions src/main/java/com/penguineering/hareairis/ai/AIChatService.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
package com.penguineering.hareairis.ai;

import com.penguineering.hareairis.model.ChatRequest;
import com.penguineering.hareairis.model.ChatResponse;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.stereotype.Service;

/**
* Service to handle chat requests.
*
* <p>Uses the OpenAI ChatClient to handle chat requests.</p>
*/
@Service
public class AIChatService {
private static final Logger logger = LoggerFactory.getLogger(AIChatService.class);
private final ChatClient.Builder chatClientBuilder;

public AIChatService(ChatClient.Builder chatClientBuilder) {
this.chatClientBuilder = chatClientBuilder;
}

/**
* Handles a chat request.
*
* @param chatRequest The chat request to handle.
* @return The chat response.
*/
public ChatResponse handleChatRequest(ChatRequest chatRequest) {
ChatClient chatClient = chatClientBuilder.build();
var chatResponse = chatClient
.prompt()
.user(chatRequest.getMessage())
.call()
.chatResponse();

String response = chatResponse.getResult().getOutput().getContent();

logger.info("Received response from OpenAI: {}", response);

Long promptTokens = chatResponse.getMetadata().getUsage().getPromptTokens();
Long generationTokens = chatResponse.getMetadata().getUsage().getGenerationTokens();

return ChatResponse.builder()
.response(response)
.inputTokens(promptTokens.intValue())
.outputTokens(generationTokens.intValue())
.build();
}
}
46 changes: 46 additions & 0 deletions src/main/java/com/penguineering/hareairis/model/ChatError.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
package com.penguineering.hareairis.model;

import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import com.fasterxml.jackson.annotation.JsonProperty;
import lombok.AllArgsConstructor;
import lombok.Getter;
import lombok.NoArgsConstructor;

@NoArgsConstructor
@Getter
@JsonIgnoreProperties(ignoreUnknown = true)
public class ChatError extends RuntimeException {
@Getter
@AllArgsConstructor
public enum Code {
CODE_BAD_REQUEST(400),
CODE_TOO_MANY_REQUESTS(429),
CODE_INTERNAL_SERVER_ERROR(500),
CODE_GATEWAY_TIMEOUT(504);

private final int code;

}

@JsonProperty("code")
private int code;

@JsonProperty("message")
private String message;

public ChatError(String message) {
super(message);
this.code = Code.CODE_INTERNAL_SERVER_ERROR.getCode();
this.message = message;
}

public ChatError(Code code, String message) {
super(message);
this.code = code.getCode();
this.message = message;
}

public boolean is5xxServerError() {
return code >= 500 && code < 600;
}
}
10 changes: 4 additions & 6 deletions src/main/java/com/penguineering/hareairis/model/ChatRequest.java
Original file line number Diff line number Diff line change
@@ -1,19 +1,17 @@
package com.penguineering.hareairis.model;

import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import com.fasterxml.jackson.annotation.JsonProperty;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.Getter;
import lombok.NoArgsConstructor;

import java.util.UUID;

@Data
@NoArgsConstructor
@AllArgsConstructor
@Getter
@JsonIgnoreProperties(ignoreUnknown = true)
public class ChatRequest {
@JsonProperty("request_id")
private UUID requestId;

@JsonProperty("message")
private String message;
}
20 changes: 11 additions & 9 deletions src/main/java/com/penguineering/hareairis/model/ChatResponse.java
Original file line number Diff line number Diff line change
@@ -1,19 +1,21 @@
package com.penguineering.hareairis.model;

import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import com.fasterxml.jackson.annotation.JsonProperty;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;
import lombok.*;

import java.util.UUID;

@Data
@NoArgsConstructor
@AllArgsConstructor
@Getter
@Builder(toBuilder = true)
@JsonIgnoreProperties(ignoreUnknown = true)
public class ChatResponse {
@JsonProperty("request_id")
private UUID requestId;

@JsonProperty("response")
private String response;

@JsonProperty("input-tokens")
private int inputTokens;

@JsonProperty("output-tokens")
private int outputTokens;
}
138 changes: 117 additions & 21 deletions src/main/java/com/penguineering/hareairis/rmq/ChatRequestHandler.java
Original file line number Diff line number Diff line change
@@ -1,59 +1,155 @@
package com.penguineering.hareairis.rmq;

import com.fasterxml.jackson.databind.ObjectMapper;
import com.penguineering.hareairis.ai.AIChatService;
import com.penguineering.hareairis.model.ChatError;
import com.penguineering.hareairis.model.ChatRequest;
import com.penguineering.hareairis.model.ChatResponse;
import com.rabbitmq.client.Channel;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.amqp.core.Message;
import org.springframework.amqp.core.MessageProperties;
import org.springframework.amqp.rabbit.core.RabbitTemplate;
import org.springframework.amqp.rabbit.listener.api.ChannelAwareMessageListener;
import org.springframework.stereotype.Component;

import java.io.IOException;
import java.util.Optional;

/**
* Handles chat requests from RabbitMQ.
*
* <p>Handles chat requests from RabbitMQ, processes them using the AiChatService and sends the response back to the
* replyTo queue.</p>
*/
@Component
public class ChatRequestHandler {
public class ChatRequestHandler implements ChannelAwareMessageListener {

private static final Logger logger = LoggerFactory.getLogger(ChatRequestHandler.class);
private final ObjectMapper objectMapper;
private final ChatClient.Builder chatClientBuilder;
private final AIChatService aiChatService;
private final RabbitTemplate rabbitTemplate;

public ChatRequestHandler(ObjectMapper objectMapper,
ChatClient.Builder builder,
AIChatService aiChatService,
RabbitTemplate rabbitTemplate) {
this.objectMapper = objectMapper;
this.chatClientBuilder = builder;
this.aiChatService = aiChatService;
this.rabbitTemplate = rabbitTemplate;
}

public void handleMessage(Message message) {

/**
* Handles a chat request.
*
* <p>Handles a chat request, processes it using the AiChatService and sends the response back to the replyTo queue.</p>
*
* @param message The chat request message.
*/
@Override
public void onMessage(Message message, Channel channel) {
long deliveryTag = message.getMessageProperties().getDeliveryTag();

// Extract the correlation ID
Optional<String> correlationId = Optional
.ofNullable(message.getMessageProperties())
.map(MessageProperties::getCorrelationId);
correlationId.ifPresentOrElse(
id -> logger.info("Received a chat request with Correlation ID: {}", id),
() -> logger.warn("Received a chat request without Correlation ID")
);

// Extract the custom error queue header
Optional<String> errorTo = Optional
.ofNullable(message.getMessageProperties())
.map(props -> props.getHeader("error_to"))
.map(String.class::cast);
if (errorTo.isEmpty())
logger.warn("Error_to header not provided, errors will be logged only!");

try {
logger.info("Received message: {}", new String(message.getBody()));
ChatRequest chatRequest = objectMapper.readValue(message.getBody(), ChatRequest.class);
ChatRequest chatRequest = deserializeChatRequest(message);

// Extract the "reply-to" header
MessageProperties properties = message.getMessageProperties();
String replyTo = properties.getReplyTo();
// Extract the "reply_to" property
String replyTo = Optional
.ofNullable(message.getMessageProperties())
.map(MessageProperties::getReplyTo)
.orElseThrow(() -> new ChatError(ChatError.Code.CODE_BAD_REQUEST, "Reply_to property is missing"));
logger.info("Reply-to header: {}", replyTo);

ChatClient chatClient = chatClientBuilder.build();
String response = chatClient
.prompt(chatRequest.getMessage())
.call()
.content();

logger.info("Received response from OpenAI: {}", response);

ChatResponse chatResponse = new ChatResponse(chatRequest.getRequestId(), response);
ChatResponse result = aiChatService.handleChatRequest(chatRequest);

// Convert ChatResponse to JSON
String jsonResponse = objectMapper.writeValueAsString(chatResponse);
String jsonResponse = serializeChatResponse(result);

// Send the response to the replyTo queue
rabbitTemplate.convertAndSend(replyTo, jsonResponse);
MessageProperties messageProperties = new MessageProperties();
correlationId.ifPresent(messageProperties::setCorrelationId);
messageProperties.setContentType("application/json");
Message responseMessage = new Message(jsonResponse.getBytes(), messageProperties);
rabbitTemplate.send(replyTo, responseMessage);

// Acknowledge the message
channel.basicAck(deliveryTag, false);
} catch (Exception e) {
logger.error("Failed to process message", e);
Optional<String> json = serializeChatError(e);
errorTo.ifPresentOrElse(
to -> json.ifPresent(
j -> rabbitTemplate.convertAndSend(to, j)),
() -> logger.error("Error on handling chat request!", e)
);

doExceptionBasedAck(e, channel, deliveryTag);
}
}

private void doExceptionBasedAck(Exception e, Channel channel, long deliveryTag) {
try {
if (e instanceof ChatError chatError)
if (chatError.is5xxServerError())
// Acknowledge the message
channel.basicAck(deliveryTag, false);

// Do not acknowledge the message
channel.basicNack(deliveryTag, false, true);
} catch (IOException ex) {
logger.error("Failed send message (n)ack!", ex);
}
}

private ChatRequest deserializeChatRequest(Message message) throws ChatError {
try {
return objectMapper.readValue(message.getBody(), ChatRequest.class);
} catch (Exception e) {
throw new ChatError(ChatError.Code.CODE_BAD_REQUEST,
"Failed to deserialize chat request: " + e.getMessage());
}
}

private String serializeChatResponse(ChatResponse response) throws ChatError {
try {
return objectMapper.writeValueAsString(response);
} catch (Exception e) {
logger.error("Failed to serialize chat response", e);
throw new ChatError(ChatError.Code.CODE_INTERNAL_SERVER_ERROR,
"Failed to serialize chat response: " + e.getMessage());
}
}

private Optional<String> serializeChatError(Exception e) {
Optional<ChatError> error = e instanceof ChatError
? Optional.of((ChatError) e)
: Optional.of(new ChatError(e.getMessage()));

try {
return objectMapper.writeValueAsString(error).describeConstable();
} catch (Exception ex) {
logger.error("Failed to serialize error", ex);
return Optional.empty();
}

}
}
17 changes: 5 additions & 12 deletions src/main/java/com/penguineering/hareairis/rmq/RabbitMQConfig.java
Original file line number Diff line number Diff line change
@@ -1,15 +1,13 @@
package com.penguineering.hareairis.rmq;

import org.springframework.amqp.core.AcknowledgeMode;
import org.springframework.amqp.core.Queue;
import org.springframework.amqp.rabbit.annotation.EnableRabbit;
import org.springframework.amqp.rabbit.connection.ConnectionFactory;
import org.springframework.amqp.rabbit.listener.SimpleMessageListenerContainer;
import org.springframework.amqp.rabbit.listener.adapter.MessageListenerAdapter;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.amqp.support.converter.MessageConverter;
import org.springframework.amqp.support.converter.SimpleMessageConverter;

@Configuration
@EnableRabbit
Expand All @@ -25,18 +23,13 @@ public Queue chatRequestsQueue() {

@Bean
public SimpleMessageListenerContainer chatRequestsContainer(ConnectionFactory connectionFactory,
MessageListenerAdapter listenerAdapter) {
ChatRequestHandler handler) {
SimpleMessageListenerContainer container = new SimpleMessageListenerContainer();
container.setConnectionFactory(connectionFactory);
container.setQueueNames(queueChatRequests);
container.setMessageListener(listenerAdapter);
container.setMessageListener(handler);
container.setAcknowledgeMode(AcknowledgeMode.MANUAL);
container.setChannelTransacted(true);
return container;
}

@Bean
public MessageListenerAdapter chatRequestsListenerAdapter(ChatRequestHandler handler) {
MessageListenerAdapter adapter = new MessageListenerAdapter(handler, "handleMessage");
adapter.setMessageConverter(null); // Ensure the whole message is passed
return adapter;
}
}

0 comments on commit e04704e

Please sign in to comment.