Skip to content

Commit

Permalink
Updated to support CDI creation for RAG support.
Browse files Browse the repository at this point in the history
  • Loading branch information
TheEliteGentleman committed Dec 14, 2024
1 parent 95d65d2 commit 0cca08b
Show file tree
Hide file tree
Showing 3 changed files with 101 additions and 30 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -4,63 +4,100 @@
import java.util.ArrayList;
import java.util.List;

import jakarta.enterprise.inject.Instance;
import jakarta.enterprise.inject.literal.NamedLiteral;

import org.jboss.logging.Logger;

import dev.langchain4j.memory.ChatMemory;
import dev.langchain4j.memory.chat.ChatMemoryProvider;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.chat.StreamingChatLanguageModel;
import dev.langchain4j.model.moderation.ModerationModel;
import dev.langchain4j.rag.RetrievalAugmentor;
import dev.langchain4j.rag.content.retriever.ContentRetriever;
import dev.langchain4j.service.AiServices;
import io.smallrye.llm.core.langchain4j.core.config.spi.ChatMemoryFactoryProvider;
import io.smallrye.llm.spi.RegisterAIService;
import jakarta.enterprise.inject.Instance;
import jakarta.enterprise.inject.literal.NamedLiteral;

public class CommonAIServiceCreator {

private static final Logger LOGGER = Logger.getLogger(CommonAIServiceCreator.class);

@SuppressWarnings("unchecked")
public static Object create(Instance<Object> lookup, Class<?> interfaceClass) {
public static <X> X create(Instance<Object> lookup, Class<X> interfaceClass) {
RegisterAIService annotation = interfaceClass.getAnnotation(RegisterAIService.class);
Instance<ChatLanguageModel> chatLanguageModel = getInstance(lookup, ChatLanguageModel.class,
annotation.chatLanguageModelName());
Instance<StreamingChatLanguageModel> streamingChatLanguageModel = getInstance(lookup, StreamingChatLanguageModel.class,
annotation.streamingChatLanguageModelName());
Instance<ContentRetriever> contentRetriever = getInstance(lookup, ContentRetriever.class,
annotation.contentRetrieverName());
Instance<RetrievalAugmentor> retrievalAugmentor = getInstance(lookup, RetrievalAugmentor.class,
annotation.retrievalAugmentorName());
try {
AiServices<?> aiServices = AiServices.builder(interfaceClass);
if (chatLanguageModel.isResolvable()) {
AiServices<X> aiServices = AiServices.builder(interfaceClass);
if (chatLanguageModel != null && chatLanguageModel.isResolvable()) {
LOGGER.info("ChatLanguageModel " + chatLanguageModel.get());
aiServices.chatLanguageModel(chatLanguageModel.get());
}
if (contentRetriever.isResolvable()) {
if (streamingChatLanguageModel != null && streamingChatLanguageModel.isResolvable()) {
LOGGER.info("StreamingChatLanguageModel " + streamingChatLanguageModel.get());
aiServices.streamingChatLanguageModel(streamingChatLanguageModel.get());
}
if (contentRetriever != null && contentRetriever.isResolvable()) {
LOGGER.info("ContentRetriever " + contentRetriever.get());
aiServices.contentRetriever(contentRetriever.get());
}
if (retrievalAugmentor != null && retrievalAugmentor.isResolvable()) {
LOGGER.info("RetrievalAugmentor " + retrievalAugmentor.get());
aiServices.retrievalAugmentor(retrievalAugmentor.get());
}
if (annotation.tools() != null && annotation.tools().length > 0) {
List<Object> tools = new ArrayList<>(annotation.tools().length);
for (Class toolClass : annotation.tools()) {
try {
tools.add(toolClass.getConstructor(null).newInstance(null));
tools.add(toolClass.getConstructor((Class<?>[])null).newInstance((Object[])null));
} catch (NoSuchMethodException | SecurityException | InstantiationException | IllegalAccessException
| IllegalArgumentException | InvocationTargetException ex) {
}
}
aiServices.tools(tools);
}
aiServices.chatMemory(
ChatMemoryFactoryProvider.getChatMemoryFactory().getChatMemory(lookup, annotation.chatMemoryMaxMessages()));

Instance<ChatMemory> chatMemory = getInstance(lookup, ChatMemory.class,
annotation.chatMemoryName());
if (chatMemory != null && chatMemory.isResolvable()) {
LOGGER.info("ChatMemory " + chatMemory.get());
aiServices.chatMemory(chatMemory.get());
}

Instance<ChatMemoryProvider> chatMemoryProvider = getInstance(lookup, ChatMemoryProvider.class,
annotation.chatMemoryProviderName());
if (chatMemoryProvider != null && chatMemoryProvider.isResolvable()) {
LOGGER.info("ChatMemoryProvider " + chatMemoryProvider.get());
aiServices.chatMemoryProvider(chatMemoryProvider.get());
}

Instance<ModerationModel> moderationModelInstance = getInstance(lookup, ModerationModel.class,
annotation.moderationModelName());
if (moderationModelInstance != null && moderationModelInstance.isResolvable()) {
LOGGER.info("ModerationModel " + moderationModelInstance.get());
aiServices.moderationModel(moderationModelInstance.get());
}

return aiServices.build();
} catch (Exception e) {
throw new RuntimeException(e);
}
}

private static Instance getInstance(Instance<Object> lookup, Class<?> type, String name) {
LOGGER.info("Getinstance of '" + type + "' with name '" + name + "'");
if (name == null || name.isBlank()) {
return lookup.select(type);
private static <X> Instance<X> getInstance(Instance<Object> lookup, Class<X> type, String name) {
LOGGER.info("CDI get instance of type '" + type + "' with name '" + name + "'");
if (name != null && !name.isBlank()) {
if ("#default".equals(name))
return lookup.select(type);

return lookup.select(type, NamedLiteral.of(name));
}
return lookup.select(type, NamedLiteral.of(name));

return null;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -5,22 +5,26 @@
import java.lang.annotation.Annotation;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Set;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.stream.Collectors;

import jakarta.enterprise.context.ApplicationScoped;
import jakarta.enterprise.inject.Instance;
import jakarta.enterprise.inject.literal.NamedLiteral;
import jakarta.enterprise.inject.spi.CDI;

import org.jboss.logging.Logger;

import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.chat.StreamingChatLanguageModel;
import dev.langchain4j.model.chat.listener.ChatModelListener;
import io.smallrye.llm.core.langchain4j.core.config.spi.LLMConfig;
import io.smallrye.llm.core.langchain4j.core.config.spi.LLMConfigProvider;
import jakarta.enterprise.context.ApplicationScoped;
import jakarta.enterprise.inject.Instance;
import jakarta.enterprise.inject.literal.NamedLiteral;
import jakarta.enterprise.inject.spi.CDI;

/*
smallrye.llm.plugin.content-retriever.class=dev.langchain4j.rag.content.retriever.EmbeddingStoreContentRetriever
Expand Down Expand Up @@ -130,7 +134,33 @@ public static Object create(Instance<Object> lookup, String beanName, Class<?> t
} else {
for (Method methodToCall : methodsToCall) {
Class<?> parameterType = methodToCall.getParameterTypes()[0];
if (stringValue.startsWith("lookup:")) {
if ("listeners".equals(property)) {
Class<?> typeParameterClass = ChatLanguageModel.class.isAssignableFrom(targetClass) || StreamingChatLanguageModel.class.isAssignableFrom(targetClass)
? ChatModelListener.class
: parameterType.getTypeParameters()[0].getGenericDeclaration();
List<Object> listeners = (List<Object>) Collections.checkedList(new ArrayList<>(),
typeParameterClass);
if ("@all".equals(stringValue.trim())) {
Instance<Object> inst = (Instance<Object>) getInstance(lookup, typeParameterClass);
if (inst != null) {
inst.forEach(listeners::add);
}
} else {
try {
for (String className : stringValue.split(",")) {
Instance<?> inst = getInstance(lookup, loadClass(className.trim()));
listeners.add(inst.get());
}
} catch (ClassNotFoundException e) {
throw new RuntimeException(e);
}
}

if (listeners != null && !listeners.isEmpty()) {
listeners.stream().forEach(l -> LOGGER.info("Adding listener: " + l.getClass().getName()));
methodToCall.invoke(builder, listeners);
}
} else if (stringValue.startsWith("lookup:")) {
String lookupableBean = stringValue.substring("lookup:".length());
LOGGER.info("Lookup " + lookupableBean + " " + parameterType);
Instance<?> inst;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,19 @@

Class<?>[] tools() default {};

String chatLanguageModelName() default "";
String chatLanguageModelName() default "#default";

String streamingChatLanguageModelName() default "";

String contentRetrieverModelName() default "";

int chatMemoryMaxMessages() default 10;

String embeddingModelName() default "";
String contentRetrieverName() default "";

String embeddingStoreName() default "";
String moderationModelName() default "";

String contentRetrieverName() default "";
String chatMemoryName() default "";

String chatMemoryProviderName() default "";

String retrievalAugmentorName() default "";
}

0 comments on commit 0cca08b

Please sign in to comment.