Skip to content

Commit

Permalink
OPIK-610 Encapsulate OpenAI-specific error handling
Browse files Browse the repository at this point in the history
  • Loading branch information
idoberko2 committed Dec 25, 2024
1 parent 07e8525 commit 3bfb047
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 14 deletions.
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
package com.comet.opik.domain.llmproviders;

import com.comet.opik.utils.JsonUtils;
import dev.ai4j.openai4j.OpenAiHttpException;
import io.dropwizard.jersey.errors.ErrorMessage;
import lombok.extern.slf4j.Slf4j;
import org.glassfish.jersey.server.ChunkedOutput;

import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.function.Consumer;
import java.util.function.Function;

@Slf4j
public class LlmProviderStreamHandler {
Expand All @@ -33,17 +34,18 @@ public void handleClose(ChunkedOutput<String> chunkedOutput) {
}
}

public void handleError(Throwable throwable, ChunkedOutput<String> chunkedOutput) {
log.error(UNEXPECTED_ERROR_CALLING_LLM_PROVIDER, throwable);
var errorMessage = new ErrorMessage(UNEXPECTED_ERROR_CALLING_LLM_PROVIDER);
if (throwable instanceof OpenAiHttpException openAiHttpException) {
errorMessage = new ErrorMessage(openAiHttpException.code(), openAiHttpException.getMessage());
}
try {
handleMessage(errorMessage, chunkedOutput);
} catch (UncheckedIOException uncheckedIOException) {
log.error("Failed to stream error message to client", uncheckedIOException);
}
handleClose(chunkedOutput);
/**
 * Builds the callback invoked when a streamed LLM call fails: logs the error,
 * converts it to a client-facing {@link ErrorMessage} via {@code mapper},
 * streams that message to the client (best effort), and closes the output.
 *
 * @param mapper        translates the provider-specific throwable into an error payload
 * @param chunkedOutput the client stream to write the error to and then close
 * @return a handler suitable for a provider client's onError hook
 */
public Consumer<Throwable> getErrorHandler(
        Function<Throwable, ErrorMessage> mapper, ChunkedOutput<String> chunkedOutput) {
    return t -> {
        log.error(UNEXPECTED_ERROR_CALLING_LLM_PROVIDER, t);

        // Map first, outside the try: a mapper failure should propagate, not be swallowed.
        ErrorMessage message = mapper.apply(t);
        try {
            handleMessage(message, chunkedOutput);
        } catch (UncheckedIOException e) {
            // The client connection may already be gone; log and still close below.
            log.error("Failed to stream error message to client", e);
        }
        handleClose(chunkedOutput);
    };
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import dev.ai4j.openai4j.chat.ChatCompletionRequest;
import dev.ai4j.openai4j.chat.ChatCompletionResponse;
import dev.langchain4j.internal.RetryUtils;
import io.dropwizard.jersey.errors.ErrorMessage;
import jakarta.inject.Inject;
import jakarta.ws.rs.ClientErrorException;
import jakarta.ws.rs.InternalServerErrorException;
Expand Down Expand Up @@ -61,7 +62,7 @@ public ChunkedOutput<String> generateStream(@NonNull ChatCompletionRequest reque
.onPartialResponse(
chatCompletionResponse -> streamHandler.handleMessage(chatCompletionResponse, chunkedOutput))
.onComplete(() -> streamHandler.handleClose(chunkedOutput))
.onError(throwable -> streamHandler.handleError(throwable, chunkedOutput))
.onError(streamHandler.getErrorHandler(this::errorMapper, chunkedOutput))
.execute();
log.info("Created and streaming chat completions, workspaceId '{}', model '{}'", workspaceId, request.model());
return chunkedOutput;
Expand Down Expand Up @@ -97,4 +98,12 @@ private OpenAiClient newOpenAiClient(String apiKey) {
.openAiApiKey(apiKey)
.build();
}

/**
 * Maps a throwable from the OpenAI client to a client-facing error payload.
 * OpenAI HTTP failures keep their status code and message; anything else
 * becomes a generic provider-error message.
 */
private ErrorMessage errorMapper(Throwable throwable) {
    return throwable instanceof OpenAiHttpException openAiHttpException
            ? new ErrorMessage(openAiHttpException.code(), openAiHttpException.getMessage())
            : new ErrorMessage(UNEXPECTED_ERROR_CALLING_LLM_PROVIDER);
}
}

0 comments on commit 3bfb047

Please sign in to comment.