[OPIK-610] multi models support infrastructure #957

Draft: wants to merge 25 commits into base: main

Changes from 8 commits

Commits (25)
ac12b8c  OPIK-610 initial structure (idoberko2, Dec 24, 2024)
6fc37df  OPIK-610 remove unused classes (idoberko2, Dec 24, 2024)
68df6bf  OPIK-610 fix injection (idoberko2, Dec 24, 2024)
def0f49  OPIK-610 move client creation to provider class (idoberko2, Dec 24, 2024)
07e8525  OPIK-610 removed unnecessary interface (idoberko2, Dec 25, 2024)
3bfb047  OPIK-610 encapsulate openai specific error handling (idoberko2, Dec 25, 2024)
7296801  OPIK-610 move retry logic to ChatCompletionService (idoberko2, Dec 25, 2024)
d469df6  OPIK-610 move chucked output logic to ChatCompletionService (idoberko2, Dec 25, 2024)
05c74ff  OPIK-610 pr comments (idoberko2, Dec 25, 2024)
48c0e8a  OPIK-610 fix test (idoberko2, Dec 25, 2024)
415a08f  OPIK-610 move additional responsibility to ChatCompletionService (idoberko2, Dec 25, 2024)
f63b892  OPIK-610 add anthropic [WIP] (idoberko2, Dec 25, 2024)
abdf8f1  OPIK-610 refactor error handling logic (idoberko2, Dec 25, 2024)
ab60348  OPIK-610 anthropic streaming failing test (idoberko2, Dec 26, 2024)
2f265d7  OPIK-610 factory get model (idoberko2, Dec 26, 2024)
f39f1b1  OPIK-610 factory get model openai green (idoberko2, Dec 26, 2024)
d083028  OPIK-610 factory get model anthropic green [WIP] (idoberko2, Dec 26, 2024)
c1a7f79  OPIK-610 factory get model anthropic green (idoberko2, Dec 26, 2024)
9686915  OPIK-610 anthropic streaming failing test [WIP] (idoberko2, Dec 26, 2024)
daa45c2  OPIK-610 anthropic failing test (idoberko2, Dec 26, 2024)
5a535a5  OPIK-610 improve error handling (idoberko2, Dec 26, 2024)
027891a  OPIK-610 anthropic failing test green (idoberko2, Dec 27, 2024)
203f2ce  OPIK-610 anthropic streaming failing test green (idoberko2, Dec 27, 2024)
a9f0fc9  OPIK-610 refactor (idoberko2, Dec 27, 2024)
be10cfa  OPIK-610 pr comments (idoberko2, Dec 27, 2024)
ChatCompletionService.java (modified)
@@ -1,185 +1,112 @@
package com.comet.opik.domain;

-import com.comet.opik.api.LlmProvider;
-import com.comet.opik.infrastructure.EncryptionUtils;
+import com.comet.opik.domain.llmproviders.LlmProviderFactory;
import com.comet.opik.infrastructure.LlmProviderClientConfig;
import com.comet.opik.utils.JsonUtils;
-import dev.ai4j.openai4j.OpenAiClient;
-import dev.ai4j.openai4j.OpenAiHttpException;
import dev.ai4j.openai4j.chat.ChatCompletionRequest;
import dev.ai4j.openai4j.chat.ChatCompletionResponse;
import dev.langchain4j.internal.RetryUtils;
import io.dropwizard.jersey.errors.ErrorMessage;
import jakarta.inject.Inject;
import jakarta.inject.Singleton;
-import jakarta.ws.rs.BadRequestException;
-import jakarta.ws.rs.ClientErrorException;
-import jakarta.ws.rs.InternalServerErrorException;
-import jakarta.ws.rs.ServerErrorException;
import lombok.NonNull;
import lombok.extern.slf4j.Slf4j;
-import org.apache.commons.lang3.StringUtils;
import org.glassfish.jersey.server.ChunkedOutput;
import ru.vyarus.dropwizard.guice.module.yaml.bind.Config;

import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.Optional;
+import java.util.function.Consumer;
+import java.util.function.Function;

@Singleton
@Slf4j
public class ChatCompletionService {

private static final String UNEXPECTED_ERROR_CALLING_LLM_PROVIDER = "Unexpected error calling LLM provider";

private final LlmProviderClientConfig llmProviderClientConfig;
-private final LlmProviderApiKeyService llmProviderApiKeyService;
+private final LlmProviderFactory llmProviderFactory;
private final RetryUtils.RetryPolicy retryPolicy;

@Inject
public ChatCompletionService(
-@NonNull @Config LlmProviderClientConfig llmProviderClientConfig,
-@NonNull LlmProviderApiKeyService llmProviderApiKeyService) {
-this.llmProviderApiKeyService = llmProviderApiKeyService;
+@NonNull @Config LlmProviderClientConfig llmProviderClientConfig, LlmProviderFactory llmProviderFactory) {
this.llmProviderClientConfig = llmProviderClientConfig;
+this.llmProviderFactory = llmProviderFactory;
this.retryPolicy = newRetryPolicy();
}

-private RetryUtils.RetryPolicy newRetryPolicy() {
-var retryPolicyBuilder = RetryUtils.retryPolicyBuilder();
-Optional.ofNullable(llmProviderClientConfig.getMaxAttempts()).ifPresent(retryPolicyBuilder::maxAttempts);
-Optional.ofNullable(llmProviderClientConfig.getJitterScale()).ifPresent(retryPolicyBuilder::jitterScale);
-Optional.ofNullable(llmProviderClientConfig.getBackoffExp()).ifPresent(retryPolicyBuilder::backoffExp);
-return retryPolicyBuilder
-.delayMillis(llmProviderClientConfig.getDelayMillis())
-.build();
-}

public ChatCompletionResponse create(@NonNull ChatCompletionRequest request, @NonNull String workspaceId) {
log.info("Creating chat completions, workspaceId '{}', model '{}'", workspaceId, request.model());
-var openAiClient = getAndConfigureOpenAiClient(request, workspaceId);
-ChatCompletionResponse chatCompletionResponse;
-try {
-chatCompletionResponse = retryPolicy.withRetry(() -> openAiClient.chatCompletion(request).execute());
-} catch (RuntimeException runtimeException) {
-log.error(UNEXPECTED_ERROR_CALLING_LLM_PROVIDER, runtimeException);
-if (runtimeException.getCause() instanceof OpenAiHttpException openAiHttpException) {
-if (openAiHttpException.code() >= 400 && openAiHttpException.code() <= 499) {
-throw new ClientErrorException(openAiHttpException.getMessage(), openAiHttpException.code());
-}
-throw new ServerErrorException(openAiHttpException.getMessage(), openAiHttpException.code());
-}
-throw new InternalServerErrorException(UNEXPECTED_ERROR_CALLING_LLM_PROVIDER);
-}
+var llmProviderClient = llmProviderFactory.getService(workspaceId, request.model());
+var chatCompletionResponse = retryPolicy.withRetry(() -> llmProviderClient.generate(request, workspaceId));
log.info("Created chat completions, workspaceId '{}', model '{}'", workspaceId, request.model());
return chatCompletionResponse;
}

public ChunkedOutput<String> createAndStreamResponse(
@NonNull ChatCompletionRequest request, @NonNull String workspaceId) {
log.info("Creating and streaming chat completions, workspaceId '{}', model '{}'", workspaceId, request.model());
-var openAiClient = getAndConfigureOpenAiClient(request, workspaceId);
+var llmProviderClient = llmProviderFactory.getService(workspaceId, request.model());

var chunkedOutput = new ChunkedOutput<String>(String.class, "\r\n");
-openAiClient.chatCompletion(request)
-.onPartialResponse(chatCompletionResponse -> send(chatCompletionResponse, chunkedOutput))
-.onComplete(() -> close(chunkedOutput))
-.onError(throwable -> handle(throwable, chunkedOutput))
-.execute();
+llmProviderClient.generateStream(
+request,
+workspaceId,
+getMessageHandler(chunkedOutput),
+getCloseHandler(chunkedOutput),
+getErrorHandler(chunkedOutput, llmProviderClient::mapError));
log.info("Created and streaming chat completions, workspaceId '{}', model '{}'", workspaceId, request.model());
return chunkedOutput;
}

-private OpenAiClient getAndConfigureOpenAiClient(ChatCompletionRequest request, String workspaceId) {
-var llmProvider = getLlmProvider(request.model());
-var encryptedApiKey = getEncryptedApiKey(workspaceId, llmProvider);
-return newOpenAiClient(encryptedApiKey);
-}

-/**
-* The agreed requirement is to resolve the LLM provider and its API key based on the model.
-* Currently, only OPEN AI is supported, so model param is ignored.
-* No further validation is needed on the model, as it's just forwarded in the OPEN AI request and will be rejected
-* if not valid.
-*/
-private LlmProvider getLlmProvider(String model) {
-return LlmProvider.OPEN_AI;
-}

-/**
-* Finding API keys isn't paginated at the moment, since only OPEN AI is supported.
-* Even in the future, the number of supported LLM providers per workspace is going to be very low.
-*/
-private String getEncryptedApiKey(String workspaceId, LlmProvider llmProvider) {
-return llmProviderApiKeyService.find(workspaceId).content().stream()
-.filter(providerApiKey -> llmProvider.equals(providerApiKey.provider()))
-.findFirst()
-.orElseThrow(() -> new BadRequestException("API key not configured for LLM provider '%s'".formatted(
-llmProvider.getValue())))
-.apiKey();
-}

-/**
-* Initially, only OPEN AI is supported, so no need for a more sophisticated client resolution to start with.
-* At the moment, openai4j client and also langchain4j wrappers, don't support dynamic API keys. That can imply
-* an important performance penalty for next phases. The following options should be evaluated:
-* - Cache clients, but can be unsafe.
-* - Find and evaluate other clients.
-* - Implement our own client.
-* TODO as part of : <a href="https://comet-ml.atlassian.net/browse/OPIK-522">OPIK-522</a>
-*/
-private OpenAiClient newOpenAiClient(String encryptedApiKey) {
-var openAiClientBuilder = OpenAiClient.builder();
-Optional.ofNullable(llmProviderClientConfig.getOpenAiClient())
-.map(LlmProviderClientConfig.OpenAiClientConfig::url)
-.ifPresent(baseUrl -> {
-if (StringUtils.isNotBlank(baseUrl)) {
-openAiClientBuilder.baseUrl(baseUrl);
-}
-});
-Optional.ofNullable(llmProviderClientConfig.getCallTimeout())
-.ifPresent(callTimeout -> openAiClientBuilder.callTimeout(callTimeout.toJavaDuration()));
-Optional.ofNullable(llmProviderClientConfig.getConnectTimeout())
-.ifPresent(connectTimeout -> openAiClientBuilder.connectTimeout(connectTimeout.toJavaDuration()));
-Optional.ofNullable(llmProviderClientConfig.getReadTimeout())
-.ifPresent(readTimeout -> openAiClientBuilder.readTimeout(readTimeout.toJavaDuration()));
-Optional.ofNullable(llmProviderClientConfig.getWriteTimeout())
-.ifPresent(writeTimeout -> openAiClientBuilder.writeTimeout(writeTimeout.toJavaDuration()));
-return openAiClientBuilder
-.openAiApiKey(EncryptionUtils.decrypt(encryptedApiKey))
+private RetryUtils.RetryPolicy newRetryPolicy() {
+var retryPolicyBuilder = RetryUtils.retryPolicyBuilder();
+Optional.ofNullable(llmProviderClientConfig.getMaxAttempts()).ifPresent(retryPolicyBuilder::maxAttempts);
+Optional.ofNullable(llmProviderClientConfig.getJitterScale()).ifPresent(retryPolicyBuilder::jitterScale);
+Optional.ofNullable(llmProviderClientConfig.getBackoffExp()).ifPresent(retryPolicyBuilder::backoffExp);
+return retryPolicyBuilder
+.delayMillis(llmProviderClientConfig.getDelayMillis())
.build();
}

-private void send(Object item, ChunkedOutput<String> chunkedOutput) {
-if (chunkedOutput.isClosed()) {
-log.warn("Output stream is already closed");
-return;
-}
-try {
-chunkedOutput.write(JsonUtils.writeValueAsString(item));
-} catch (IOException ioException) {
-throw new UncheckedIOException(ioException);
-}
+private <T> Consumer<T> getMessageHandler(ChunkedOutput<String> chunkedOutput) {
+return item -> {
+if (chunkedOutput.isClosed()) {
+log.warn("Output stream is already closed");
+return;
+}
+try {
+chunkedOutput.write(JsonUtils.writeValueAsString(item));
+} catch (IOException ioException) {
+throw new UncheckedIOException(ioException);
+}
+};
}

-private void handle(Throwable throwable, ChunkedOutput<String> chunkedOutput) {
-log.error(UNEXPECTED_ERROR_CALLING_LLM_PROVIDER, throwable);
-var errorMessage = new ErrorMessage(UNEXPECTED_ERROR_CALLING_LLM_PROVIDER);
-if (throwable instanceof OpenAiHttpException openAiHttpException) {
-errorMessage = new ErrorMessage(openAiHttpException.code(), openAiHttpException.getMessage());
-}
-try {
-send(errorMessage, chunkedOutput);
-} catch (UncheckedIOException uncheckedIOException) {
-log.error("Failed to stream error message to client", uncheckedIOException);
-}
-close(chunkedOutput);
+private Runnable getCloseHandler(ChunkedOutput<String> chunkedOutput) {
+return () -> {
+try {
+chunkedOutput.close();
+} catch (IOException ioException) {
+log.error("Failed to close output stream", ioException);
+}
+};
}

-private void close(ChunkedOutput<String> chunkedOutput) {
-try {
-chunkedOutput.close();
-} catch (IOException ioException) {
-log.error("Failed to close output stream", ioException);
-}
+private Consumer<Throwable> getErrorHandler(
+ChunkedOutput<String> chunkedOutput, Function<Throwable, ErrorMessage> errorMapper) {
+return throwable -> {
+log.error(UNEXPECTED_ERROR_CALLING_LLM_PROVIDER, throwable);
+
+var errorMessage = errorMapper.apply(throwable);
+try {
+getMessageHandler(chunkedOutput).accept(errorMessage);
+} catch (UncheckedIOException uncheckedIOException) {
+log.error("Failed to stream error message to client", uncheckedIOException);
+}
+getCloseHandler(chunkedOutput).run();
+};
}
}
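
Note: with this change, provider-specific error translation moves out of ChatCompletionService and behind LlmProviderService::mapError. A rough sketch of such a mapper, based only on the OpenAiHttpException handling removed above; the PR's actual OpenAi provider class is not part of this hunk and may differ:

package com.comet.opik.domain.llmproviders;

import dev.ai4j.openai4j.OpenAiHttpException;
import io.dropwizard.jersey.errors.ErrorMessage;

// Illustrative sketch only: mirrors the error handling deleted from ChatCompletionService above.
class OpenAiErrorMapperSketch {

    private static final String UNEXPECTED_ERROR_CALLING_LLM_PROVIDER = "Unexpected error calling LLM provider";

    ErrorMessage mapError(Throwable throwable) {
        // openai4j wraps HTTP failures in OpenAiHttpException, which carries the status code
        if (throwable instanceof OpenAiHttpException openAiHttpException) {
            return new ErrorMessage(openAiHttpException.code(), openAiHttpException.getMessage());
        }
        // anything else falls back to the generic message previously used by handle()
        return new ErrorMessage(UNEXPECTED_ERROR_CALLING_LLM_PROVIDER);
    }
}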
LlmProviderFactory.java (new file)
@@ -0,0 +1,58 @@
package com.comet.opik.domain.llmproviders;

import com.comet.opik.api.LlmProvider;
import com.comet.opik.domain.LlmProviderApiKeyService;
import com.comet.opik.infrastructure.EncryptionUtils;
import com.comet.opik.infrastructure.LlmProviderClientConfig;
import jakarta.inject.Inject;
import jakarta.inject.Singleton;
import jakarta.ws.rs.BadRequestException;
import lombok.NonNull;
import ru.vyarus.dropwizard.guice.module.yaml.bind.Config;

@Singleton
public class LlmProviderFactory {
private final LlmProviderClientConfig llmProviderClientConfig;
private final LlmProviderApiKeyService llmProviderApiKeyService;

@Inject
public LlmProviderFactory(
@NonNull @Config LlmProviderClientConfig llmProviderClientConfig,
@NonNull LlmProviderApiKeyService llmProviderApiKeyService) {
this.llmProviderApiKeyService = llmProviderApiKeyService;
this.llmProviderClientConfig = llmProviderClientConfig;
}

public LlmProviderService getService(@NonNull String workspaceId, @NonNull String model) {
var llmProvider = getLlmProvider(model);
if (llmProvider == LlmProvider.OPEN_AI) {
var apiKey = EncryptionUtils.decrypt(getEncryptedApiKey(workspaceId, llmProvider));
return new OpenAi(llmProviderClientConfig, apiKey);
}

throw new IllegalArgumentException("not supported provider " + llmProvider);
}

/**
* The agreed requirement is to resolve the LLM provider and its API key based on the model.
* Currently, only OPEN AI is supported, so model param is ignored.
* No further validation is needed on the model, as it's just forwarded in the OPEN AI request and will be rejected
* if not valid.
*/
private LlmProvider getLlmProvider(String model) {
return LlmProvider.OPEN_AI;
}

/**
* Finding API keys isn't paginated at the moment, since only OPEN AI is supported.
* Even in the future, the number of supported LLM providers per workspace is going to be very low.
*/
private String getEncryptedApiKey(String workspaceId, LlmProvider llmProvider) {
return llmProviderApiKeyService.find(workspaceId).content().stream()
.filter(providerApiKey -> llmProvider.equals(providerApiKey.provider()))
.findFirst()
.orElseThrow(() -> new BadRequestException("API key not configured for LLM provider '%s'".formatted(
llmProvider.getValue())))
.apiKey();
}
}
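
Since the point of the factory is to admit more providers later (the later commits add Anthropic), here is a hypothetical sketch of how a second branch could slot into getService. OPEN_AI and OpenAi come from this diff; the ANTHROPIC constant and the Anthropic class are assumptions based on the commit messages, not code in this hunk:

    // Hypothetical drop-in replacement for getService above; ANTHROPIC and Anthropic are assumed.
    public LlmProviderService getService(@NonNull String workspaceId, @NonNull String model) {
        var llmProvider = getLlmProvider(model);
        var apiKey = EncryptionUtils.decrypt(getEncryptedApiKey(workspaceId, llmProvider));
        return switch (llmProvider) {
            case OPEN_AI -> new OpenAi(llmProviderClientConfig, apiKey);
            case ANTHROPIC -> new Anthropic(llmProviderClientConfig, apiKey); // assumed constructor
            default -> throw new IllegalArgumentException("not supported provider " + llmProvider);
        };
    }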
LlmProviderService.java (new file)
@@ -0,0 +1,23 @@
package com.comet.opik.domain.llmproviders;

import dev.ai4j.openai4j.chat.ChatCompletionRequest;
import dev.ai4j.openai4j.chat.ChatCompletionResponse;
import io.dropwizard.jersey.errors.ErrorMessage;
import lombok.NonNull;

import java.util.function.Consumer;

public interface LlmProviderService {
ChatCompletionResponse generate(
@NonNull ChatCompletionRequest request,
@NonNull String workspaceId);

void generateStream(
@NonNull ChatCompletionRequest request,
@NonNull String workspaceId,
@NonNull Consumer<ChatCompletionResponse> handleMessage,
@NonNull Runnable handleClose,
@NonNull Consumer<Throwable> handleError);

ErrorMessage mapError(Throwable throwable);
}
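
Each vendor integration implements this contract and is handed out by LlmProviderFactory. A bare-bones skeleton, purely illustrative; the class name and method bodies are placeholders, not the PR's actual OpenAi or Anthropic implementations:

package com.comet.opik.domain.llmproviders;

import dev.ai4j.openai4j.chat.ChatCompletionRequest;
import dev.ai4j.openai4j.chat.ChatCompletionResponse;
import io.dropwizard.jersey.errors.ErrorMessage;
import lombok.NonNull;

import java.util.function.Consumer;

// Placeholder skeleton showing what LlmProviderFactory expects from a provider.
class ExampleLlmProvider implements LlmProviderService {

    @Override
    public ChatCompletionResponse generate(@NonNull ChatCompletionRequest request, @NonNull String workspaceId) {
        // call the vendor SDK synchronously and adapt its response to ChatCompletionResponse
        throw new UnsupportedOperationException("placeholder");
    }

    @Override
    public void generateStream(@NonNull ChatCompletionRequest request, @NonNull String workspaceId,
            @NonNull Consumer<ChatCompletionResponse> handleMessage, @NonNull Runnable handleClose,
            @NonNull Consumer<Throwable> handleError) {
        // subscribe to the vendor's streaming API and forward chunks, completion and errors
        // to the handlers wired up by ChatCompletionService.createAndStreamResponse
        throw new UnsupportedOperationException("placeholder");
    }

    @Override
    public ErrorMessage mapError(Throwable throwable) {
        // translate vendor-specific exceptions into an HTTP-friendly ErrorMessage
        return new ErrorMessage("Unexpected error calling LLM provider");
    }
}

The streaming callbacks (handleMessage, handleClose, handleError) are exactly the three handlers ChatCompletionService builds around the ChunkedOutput, which keeps all Jersey streaming concerns out of the provider classes.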