[OPIK-610] multi models support infrastructure #957

Draft: wants to merge 25 commits into main
Changes from 3 commits
Commits (25)
ac12b8c
OPIK-610 initial structure
idoberko2 Dec 24, 2024
6fc37df
OPIK-610 remove unused classes
idoberko2 Dec 24, 2024
68df6bf
OPIK-610 fix injection
idoberko2 Dec 24, 2024
def0f49
OPIK-610 move client creation to provider class
idoberko2 Dec 24, 2024
07e8525
OPIK-610 removed unnecessary interface
idoberko2 Dec 24, 2024
3bfb047
OPIK-610 encapsulate openai specific error handling
idoberko2 Dec 25, 2024
7296801
OPIK-610 move retry logic to ChatCompletionService
idoberko2 Dec 25, 2024
d469df6
OPIK-610 move chunked output logic to ChatCompletionService
idoberko2 Dec 25, 2024
05c74ff
OPIK-610 pr comments
idoberko2 Dec 25, 2024
48c0e8a
OPIK-610 fix test
idoberko2 Dec 25, 2024
415a08f
OPIK-610 move additional responsibility to ChatCompletionService
idoberko2 Dec 25, 2024
f63b892
OPIK-610 add anthropic [WIP]
idoberko2 Dec 25, 2024
abdf8f1
OPIK-610 refactor error handling logic
idoberko2 Dec 25, 2024
ab60348
OPIK-610 anthropic streaming failing test
idoberko2 Dec 26, 2024
2f265d7
OPIK-610 factory get model
idoberko2 Dec 26, 2024
f39f1b1
OPIK-610 factory get model openai green
idoberko2 Dec 26, 2024
d083028
OPIK-610 factory get model anthropic green [WIP]
idoberko2 Dec 26, 2024
c1a7f79
OPIK-610 factory get model anthropic green
idoberko2 Dec 26, 2024
9686915
OPIK-610 anthropic streaming failing test [WIP]
idoberko2 Dec 26, 2024
daa45c2
OPIK-610 anthropic failing test
idoberko2 Dec 26, 2024
5a535a5
OPIK-610 improve error handling
idoberko2 Dec 26, 2024
027891a
OPIK-610 anthropic failing test green
idoberko2 Dec 27, 2024
203f2ce
OPIK-610 anthropic streaming failing test green
idoberko2 Dec 27, 2024
a9f0fc9
OPIK-610 refactor
idoberko2 Dec 27, 2024
be10cfa
OPIK-610 pr comments
idoberko2 Dec 27, 2024
4 changes: 4 additions & 0 deletions apps/opik-backend/pom.xml
@@ -211,6 +211,10 @@
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-open-ai</artifactId>
</dependency>
<dependency>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-anthropic</artifactId>
</dependency>

<!-- Test -->

@@ -1,6 +1,7 @@
package com.comet.opik.domain;

import com.comet.opik.domain.llmproviders.LlmProviderFactory;
import com.comet.opik.domain.llmproviders.LlmProviderService;
import com.comet.opik.infrastructure.LlmProviderClientConfig;
import com.comet.opik.utils.JsonUtils;
import dev.ai4j.openai4j.chat.ChatCompletionRequest;
@@ -9,6 +10,9 @@
import io.dropwizard.jersey.errors.ErrorMessage;
import jakarta.inject.Inject;
import jakarta.inject.Singleton;
import jakarta.ws.rs.ClientErrorException;
import jakarta.ws.rs.InternalServerErrorException;
import jakarta.ws.rs.ServerErrorException;
import lombok.NonNull;
import lombok.extern.slf4j.Slf4j;
import org.glassfish.jersey.server.ChunkedOutput;
@@ -18,7 +22,6 @@
import java.io.UncheckedIOException;
import java.util.Optional;
import java.util.function.Consumer;
import java.util.function.Function;

@Singleton
@Slf4j
@@ -43,10 +46,20 @@ public ChatCompletionResponse create(@NonNull ChatCompletionRequest request, @No

ChatCompletionResponse chatCompletionResponse;
try {
log.info("Creating chat completions, workspaceId '{}', model '{}'", workspaceId, request.model());
chatCompletionResponse = retryPolicy.withRetry(() -> llmProviderClient.generate(request, workspaceId));
log.info("Created chat completions, workspaceId '{}', model '{}'", workspaceId, request.model());
} catch (RuntimeException runtimeException) {
log.error(UNEXPECTED_ERROR_CALLING_LLM_PROVIDER, runtimeException);
throw llmProviderClient.mapRuntimeException(runtimeException);
if (llmProviderClient.getHttpExceptionClass().isInstance(runtimeException.getCause())) {
int statusCode = llmProviderClient.getHttpErrorStatusCode(runtimeException.getCause());
if (statusCode >= 400 && statusCode <= 499) {
throw new ClientErrorException(runtimeException.getMessage(), statusCode);
}

throw new ServerErrorException(runtimeException.getMessage(), statusCode);
}
throw new InternalServerErrorException(UNEXPECTED_ERROR_CALLING_LLM_PROVIDER);
}

log.info("Created chat completions, workspaceId '{}', model '{}'", workspaceId, request.model());
@@ -59,12 +72,13 @@ public ChunkedOutput<String> createAndStreamResponse(
var llmProviderClient = llmProviderFactory.getService(workspaceId, request.model());

var chunkedOutput = new ChunkedOutput<String>(String.class, "\r\n");
log.info("Creating and streaming chat completions, workspaceId '{}', model '{}'", workspaceId, request.model());
llmProviderClient.generateStream(
request,
workspaceId,
getMessageHandler(chunkedOutput),
getCloseHandler(chunkedOutput),
getErrorHandler(chunkedOutput, llmProviderClient::mapThrowableToError));
getErrorHandler(chunkedOutput, llmProviderClient));
log.info("Created and streaming chat completions, workspaceId '{}', model '{}'", workspaceId, request.model());
return chunkedOutput;
}
@@ -104,11 +118,16 @@ private Runnable getCloseHandler(ChunkedOutput<String> chunkedOutput) {
}

private Consumer<Throwable> getErrorHandler(
ChunkedOutput<String> chunkedOutput, Function<Throwable, ErrorMessage> errorMapper) {
ChunkedOutput<String> chunkedOutput, LlmProviderService llmProviderClient) {
return throwable -> {
log.error(UNEXPECTED_ERROR_CALLING_LLM_PROVIDER, throwable);

var errorMessage = errorMapper.apply(throwable);
var errorMessage = new ErrorMessage(ChatCompletionService.UNEXPECTED_ERROR_CALLING_LLM_PROVIDER);
if (llmProviderClient.getHttpExceptionClass().isInstance(throwable)) {
errorMessage = new ErrorMessage(llmProviderClient.getHttpErrorStatusCode(throwable),
throwable.getMessage());
}

try {
getMessageHandler(chunkedOutput).accept(errorMessage);
} catch (UncheckedIOException uncheckedIOException) {
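Taken together, this change makes the error mapping in create() provider-agnostic: the provider only reports its HTTP exception type and the upstream status code, and this service decides how the failure surfaces. A condensed restatement of that rule follows, purely for review convenience; it is not new code in the PR and assumes the provider hooks are handed the unwrapped cause, as in the catch block above.

    static WebApplicationException toWebApplicationException(RuntimeException failure, LlmProviderService provider) {
        // provider-specific HTTP failures keep their upstream status code
        if (provider.getHttpExceptionClass().isInstance(failure.getCause())) {
            int statusCode = provider.getHttpErrorStatusCode(failure.getCause());
            return statusCode >= 400 && statusCode <= 499
                    ? new ClientErrorException(failure.getMessage(), statusCode)
                    : new ServerErrorException(failure.getMessage(), statusCode);
        }
        // anything else is reported as a generic 500 with the shared error message
        return new InternalServerErrorException(ChatCompletionService.UNEXPECTED_ERROR_CALLING_LLM_PROVIDER);
    }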
@@ -0,0 +1,190 @@
package com.comet.opik.domain.llmproviders;

import com.comet.opik.infrastructure.LlmProviderClientConfig;
import dev.ai4j.openai4j.chat.AssistantMessage;
import dev.ai4j.openai4j.chat.ChatCompletionChoice;
import dev.ai4j.openai4j.chat.ChatCompletionRequest;
import dev.ai4j.openai4j.chat.ChatCompletionResponse;
import dev.ai4j.openai4j.chat.Message;
import dev.ai4j.openai4j.chat.Role;
import dev.ai4j.openai4j.shared.Usage;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.Content;
import dev.langchain4j.data.message.TextContent;
import dev.langchain4j.data.message.UserMessage;
import dev.langchain4j.model.StreamingResponseHandler;
import dev.langchain4j.model.anthropic.internal.api.AnthropicContent;
import dev.langchain4j.model.anthropic.internal.api.AnthropicCreateMessageRequest;
import dev.langchain4j.model.anthropic.internal.api.AnthropicCreateMessageResponse;
import dev.langchain4j.model.anthropic.internal.api.AnthropicMessage;
import dev.langchain4j.model.anthropic.internal.api.AnthropicMessageContent;
import dev.langchain4j.model.anthropic.internal.api.AnthropicRole;
import dev.langchain4j.model.anthropic.internal.api.AnthropicTextContent;
import dev.langchain4j.model.anthropic.internal.api.AnthropicToolChoice;
import dev.langchain4j.model.anthropic.internal.client.AnthropicClient;
import dev.langchain4j.model.anthropic.internal.client.AnthropicHttpException;
import dev.langchain4j.model.output.Response;
import jakarta.ws.rs.BadRequestException;
import lombok.NonNull;
import org.apache.commons.lang3.StringUtils;

import java.util.List;
import java.util.Optional;
import java.util.function.Consumer;

public class Anthropic implements LlmProviderService {
private final LlmProviderClientConfig llmProviderClientConfig;
private final AnthropicClient anthropicClient;

public Anthropic(LlmProviderClientConfig llmProviderClientConfig, String apiKey) {
this.llmProviderClientConfig = llmProviderClientConfig;
this.anthropicClient = newClient(apiKey);
}

@Override
public ChatCompletionResponse generate(@NonNull ChatCompletionRequest request, @NonNull String workspaceId) {
var response = anthropicClient.createMessage(AnthropicCreateMessageRequest.builder()
.model(request.model())
.messages(request.messages().stream().map(this::mapMessage).toList())
.temperature(request.temperature())
.topP(request.topP())
.stopSequences(request.stop())
.maxTokens(request.maxTokens())
.toolChoice(AnthropicToolChoice.from(request.toolChoice().toString()))
.build());

return ChatCompletionResponse.builder()
.id(response.id)
.model(response.model)
.choices(response.content.stream().map(content -> mapContentToChoice(response, content))
.toList())
.usage(Usage.builder()
.promptTokens(response.usage.inputTokens)
.completionTokens(response.usage.outputTokens)
.totalTokens(response.usage.inputTokens + response.usage.outputTokens)
.build())
.build();
}

@Override
public void generateStream(
@NonNull ChatCompletionRequest request,
@NonNull String workspaceId,
@NonNull Consumer<ChatCompletionResponse> handleMessage,
@NonNull Runnable handleClose, @NonNull Consumer<Throwable> handleError) {
anthropicClient.createMessage(AnthropicCreateMessageRequest.builder()
.model(request.model())
.messages(request.messages().stream().map(this::mapMessage).toList())
.temperature(request.temperature())
.topP(request.topP())
.stopSequences(request.stop())
.maxTokens(request.maxTokens())
.toolChoice(AnthropicToolChoice.from(request.toolChoice().toString()))
.build(), new ChunkedResponseHandler(handleMessage, handleClose, handleError));
}

@Override
public Class<? extends Throwable> getHttpExceptionClass() {
return AnthropicHttpException.class;
}

@Override
public int getHttpErrorStatusCode(Throwable throwable) {
if (throwable instanceof AnthropicHttpException anthropicHttpException) {
return anthropicHttpException.statusCode();
}

return 500;
}

private AnthropicMessage mapMessage(Message message) {
if (message.role() == Role.ASSISTANT) {
return AnthropicMessage.builder()
.role(AnthropicRole.ASSISTANT)
.content(List.of(new AnthropicTextContent(((AssistantMessage) message).content())))
.build();
} else if (message.role() == Role.USER) {
return AnthropicMessage.builder()
.role(AnthropicRole.USER)
.content(((UserMessage) message).contents().stream().map(this::mapMessageContent).toList())
.build();
}

// Anthropic only supports assistant and user roles
throw new BadRequestException("not supported message role: " + message.role());
}

private AnthropicMessageContent mapMessageContent(Content content) {
if (content instanceof TextContent) {
return new AnthropicTextContent(((TextContent) content).text());
}

throw new BadRequestException("only text content is supported");
}

private ChatCompletionChoice mapContentToChoice(AnthropicCreateMessageResponse response, AnthropicContent content) {
return ChatCompletionChoice.builder()
.message(AssistantMessage.builder()
.name(content.name)
.content(content.text)
.build())
.finishReason(response.stopReason)
.build();
}

private AnthropicClient newClient(String apiKey) {
var anthropicClientBuilder = AnthropicClient.builder();
Optional.ofNullable(llmProviderClientConfig.getOpenAiClient())
.map(LlmProviderClientConfig.OpenAiClientConfig::url)
.ifPresent(baseUrl -> {
if (StringUtils.isNotBlank(baseUrl)) {
anthropicClientBuilder.baseUrl(baseUrl);
}
});
Optional.ofNullable(llmProviderClientConfig.getCallTimeout())
.ifPresent(callTimeout -> anthropicClientBuilder.timeout(callTimeout.toJavaDuration()));
return anthropicClientBuilder
.apiKey(apiKey)
.build();
}

private static class ChunkedResponseHandler implements StreamingResponseHandler<AiMessage> {
private final Consumer<ChatCompletionResponse> handleMessage;
private final Runnable handleClose;
private final Consumer<Throwable> handleError;

public ChunkedResponseHandler(
Consumer<ChatCompletionResponse> handleMessage, Runnable handleClose, Consumer<Throwable> handleError) {
this.handleMessage = handleMessage;
this.handleClose = handleClose;
this.handleError = handleError;
}

@Override
public void onNext(String s) {
// not implemented yet: streamed chunks are not forwarded to handleMessage
}

@Override
public void onComplete(Response<AiMessage> response) {
// delegates to the default no-op; completion is not yet signalled via handleClose
StreamingResponseHandler.super.onComplete(response);
}

@Override
public void onError(Throwable throwable) {
handleError.accept(throwable);
}
}
}
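The streaming handler above is still a stub at this point in the commit range (this diff covers only the first 3 of the 25 commits). As a rough, non-authoritative sketch of the direction it would need to take, each streamed token could be wrapped into a partial ChatCompletionResponse and forwarded, with completion signalled through handleClose. The exact shape is an assumption, not necessarily what the later commits do; it reuses only builder calls already present in this file.

    @Override
    public void onNext(String token) {
        // assumption: wrap each streamed token as an assistant-message chunk and forward it
        handleMessage.accept(ChatCompletionResponse.builder()
                .choices(List.of(ChatCompletionChoice.builder()
                        .message(AssistantMessage.builder().content(token).build())
                        .build()))
                .build());
    }

    @Override
    public void onComplete(Response<AiMessage> response) {
        // assumption: signal the end of the stream so the chunked output can be closed
        handleClose.run();
    }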
@@ -2,8 +2,6 @@

import dev.ai4j.openai4j.chat.ChatCompletionRequest;
import dev.ai4j.openai4j.chat.ChatCompletionResponse;
import io.dropwizard.jersey.errors.ErrorMessage;
import jakarta.ws.rs.WebApplicationException;
import lombok.NonNull;

import java.util.function.Consumer;
@@ -20,7 +18,7 @@ void generateStream(
@NonNull Runnable handleClose,
@NonNull Consumer<Throwable> handleError);

WebApplicationException mapRuntimeException(RuntimeException runtimeException);
Class<? extends Throwable> getHttpExceptionClass();

ErrorMessage mapThrowableToError(Throwable throwable);
int getHttpErrorStatusCode(Throwable runtimeException);
}
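For review purposes, here is a minimal sketch of what the revised contract asks of a new provider: the two hooks replace the JAX-RS-specific mapRuntimeException/mapThrowableToError pair, so implementations stay free of HTTP-layer types. The class and exception names below are hypothetical and not part of this PR, and the sketch assumes the interface also declares the generate method that ChatCompletionService calls.

    import dev.ai4j.openai4j.chat.ChatCompletionRequest;
    import dev.ai4j.openai4j.chat.ChatCompletionResponse;
    import lombok.NonNull;

    import java.util.function.Consumer;

    // Hypothetical example only: "MyProviderHttpException" stands in for whatever the SDK throws.
    public class MyProvider implements LlmProviderService {

        public static class MyProviderHttpException extends RuntimeException {
            private final int statusCode;

            public MyProviderHttpException(int statusCode, String message) {
                super(message);
                this.statusCode = statusCode;
            }

            public int statusCode() {
                return statusCode;
            }
        }

        @Override
        public ChatCompletionResponse generate(@NonNull ChatCompletionRequest request, @NonNull String workspaceId) {
            // call the provider SDK here; HTTP failures surface as MyProviderHttpException
            throw new MyProviderHttpException(501, "not implemented in this sketch");
        }

        @Override
        public void generateStream(
                @NonNull ChatCompletionRequest request,
                @NonNull String workspaceId,
                @NonNull Consumer<ChatCompletionResponse> handleMessage,
                @NonNull Runnable handleClose,
                @NonNull Consumer<Throwable> handleError) {
            // stream from the SDK; failures go through handleError rather than being thrown
            handleError.accept(new MyProviderHttpException(501, "not implemented in this sketch"));
        }

        @Override
        public Class<? extends Throwable> getHttpExceptionClass() {
            // the exception type ChatCompletionService treats as an HTTP error from this provider
            return MyProviderHttpException.class;
        }

        @Override
        public int getHttpErrorStatusCode(Throwable throwable) {
            // surface the upstream status code; anything unexpected falls back to 500
            if (throwable instanceof MyProviderHttpException httpException) {
                return httpException.statusCode();
            }
            return 500;
        }
    }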
@@ -1,17 +1,11 @@
package com.comet.opik.domain.llmproviders;

import com.comet.opik.domain.ChatCompletionService;
import com.comet.opik.infrastructure.LlmProviderClientConfig;
import dev.ai4j.openai4j.OpenAiClient;
import dev.ai4j.openai4j.OpenAiHttpException;
import dev.ai4j.openai4j.chat.ChatCompletionRequest;
import dev.ai4j.openai4j.chat.ChatCompletionResponse;
import io.dropwizard.jersey.errors.ErrorMessage;
import jakarta.inject.Inject;
import jakarta.ws.rs.ClientErrorException;
import jakarta.ws.rs.InternalServerErrorException;
import jakarta.ws.rs.ServerErrorException;
import jakarta.ws.rs.WebApplicationException;
import lombok.NonNull;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
@@ -32,10 +26,7 @@ public OpenAi(LlmProviderClientConfig llmProviderClientConfig, String apiKey) {

@Override
public ChatCompletionResponse generate(@NonNull ChatCompletionRequest request, @NonNull String workspaceId) {
log.info("Creating chat completions, workspaceId '{}', model '{}'", workspaceId, request.model());
var chatCompletionResponse = openAiClient.chatCompletion(request).execute();;
log.info("Created chat completions, workspaceId '{}', model '{}'", workspaceId, request.model());
return chatCompletionResponse;
return openAiClient.chatCompletion(request).execute();
}

@Override
@@ -45,35 +36,25 @@ public void generateStream(
@NonNull Consumer<ChatCompletionResponse> handleMessage,
@NonNull Runnable handleClose,
@NonNull Consumer<Throwable> handleError) {
log.info("Creating and streaming chat completions, workspaceId '{}', model '{}'", workspaceId, request.model());
openAiClient.chatCompletion(request)
.onPartialResponse(handleMessage)
.onComplete(handleClose)
.onError(handleError)
.execute();
log.info("Created and streaming chat completions, workspaceId '{}', model '{}'", workspaceId, request.model());
}

@Override
public WebApplicationException mapRuntimeException(RuntimeException runtimeException) {
if (runtimeException.getCause() instanceof OpenAiHttpException openAiHttpException) {
if (openAiHttpException.code() >= 400 && openAiHttpException.code() <= 499) {
return new ClientErrorException(openAiHttpException.getMessage(), openAiHttpException.code());
}

return new ServerErrorException(openAiHttpException.getMessage(), openAiHttpException.code());
}

return new InternalServerErrorException(ChatCompletionService.UNEXPECTED_ERROR_CALLING_LLM_PROVIDER);
public Class<? extends RuntimeException> getHttpExceptionClass() {
return OpenAiHttpException.class;
}

@Override
public ErrorMessage mapThrowableToError(Throwable throwable) {
public int getHttpErrorStatusCode(Throwable throwable) {
if (throwable instanceof OpenAiHttpException openAiHttpException) {
return new ErrorMessage(openAiHttpException.code(), openAiHttpException.getMessage());
return openAiHttpException.code();
}

return new ErrorMessage(ChatCompletionService.UNEXPECTED_ERROR_CALLING_LLM_PROVIDER);
return 500;
}

/**