From aeedf91676385a7d720f98ab702dfbfdd7ff31c2 Mon Sep 17 00:00:00 2001 From: Victor Martin Date: Mon, 11 Nov 2024 21:21:36 +0100 Subject: [PATCH] support meta models --- app/src/components/content/settings.tsx | 5 +- .../backend/controller/GenAIController.java | 18 ++-- .../backend/controller/PromptController.java | 2 +- .../backend/service/GenAIModelsService.java | 44 +++++++++ .../backend/service/OCIGenAIService.java | 98 ++++++++++++++----- 5 files changed, 127 insertions(+), 40 deletions(-) create mode 100644 backend/src/main/java/dev/victormartin/oci/genai/backend/backend/service/GenAIModelsService.java diff --git a/app/src/components/content/settings.tsx b/app/src/components/content/settings.tsx index 269f66d3..babb9f71 100644 --- a/app/src/components/content/settings.tsx +++ b/app/src/components/content/settings.tsx @@ -81,9 +81,8 @@ export const Settings = (props: Props) => { const json = await response.json(); const result = json.filter((model: Model) => { if ( - model.capabilities.includes("TEXT_GENERATION") && - (model.vendor == "cohere" || model.vendor == "") && - model.version != "14.2" + model.capabilities.includes("CHAT") && + (model.vendor == "cohere" || model.vendor == "meta") ) return model; }); diff --git a/backend/src/main/java/dev/victormartin/oci/genai/backend/backend/controller/GenAIController.java b/backend/src/main/java/dev/victormartin/oci/genai/backend/backend/controller/GenAIController.java index a2e2e51f..5a3f69f8 100644 --- a/backend/src/main/java/dev/victormartin/oci/genai/backend/backend/controller/GenAIController.java +++ b/backend/src/main/java/dev/victormartin/oci/genai/backend/backend/controller/GenAIController.java @@ -8,6 +8,7 @@ import com.oracle.bmc.generativeai.responses.ListEndpointsResponse; import dev.victormartin.oci.genai.backend.backend.dao.GenAiModel; import dev.victormartin.oci.genai.backend.backend.dao.GenAiEndpoint; +import dev.victormartin.oci.genai.backend.backend.service.GenAIModelsService; import dev.victormartin.oci.genai.backend.backend.service.GenAiClientService; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -29,19 +30,16 @@ public class GenAIController { @Autowired private GenAiClientService generativeAiClientService; + @Autowired + private GenAIModelsService genAIModelsService; + @GetMapping("/api/genai/models") public List getModels() { logger.info("getModels()"); - ListModelsRequest listModelsRequest = ListModelsRequest.builder().compartmentId(COMPARTMENT_ID).build(); - GenerativeAiClient client = generativeAiClientService.getClient(); - ListModelsResponse response = client.listModels(listModelsRequest); - return response.getModelCollection().getItems().stream().map(m -> { - List capabilities = m.getCapabilities().stream().map(ModelCapability::getValue) - .collect(Collectors.toList()); - GenAiModel model = new GenAiModel(m.getId(), m.getDisplayName(), m.getVendor(), m.getVersion(), - capabilities, m.getTimeCreated()); - return model; - }).collect(Collectors.toList()); + List models = genAIModelsService.getModels(); + return models.stream() + .filter(m -> m.capabilities().contains("CHAT")) + .collect(Collectors.toList()); } @GetMapping("/api/genai/endpoints") diff --git a/backend/src/main/java/dev/victormartin/oci/genai/backend/backend/controller/PromptController.java b/backend/src/main/java/dev/victormartin/oci/genai/backend/backend/controller/PromptController.java index c7005f74..1adb7080 100644 --- a/backend/src/main/java/dev/victormartin/oci/genai/backend/backend/controller/PromptController.java +++ b/backend/src/main/java/dev/victormartin/oci/genai/backend/backend/controller/PromptController.java @@ -61,7 +61,7 @@ public Answer handlePrompt(Prompt prompt) { throw new InvalidPromptRequest(); } saved.setDatetimeResponse(new Date()); - String responseFromGenAI = genAI.resolvePrompt(promptEscaped, activeModel, finetune); + String responseFromGenAI = genAI.resolvePrompt(promptEscaped, activeModel, finetune, false); saved.setResponse(responseFromGenAI); interactionRepository.save(saved); return new Answer(responseFromGenAI, ""); diff --git a/backend/src/main/java/dev/victormartin/oci/genai/backend/backend/service/GenAIModelsService.java b/backend/src/main/java/dev/victormartin/oci/genai/backend/backend/service/GenAIModelsService.java new file mode 100644 index 00000000..b52c4aee --- /dev/null +++ b/backend/src/main/java/dev/victormartin/oci/genai/backend/backend/service/GenAIModelsService.java @@ -0,0 +1,44 @@ +package dev.victormartin.oci.genai.backend.backend.service; + +import com.oracle.bmc.generativeai.GenerativeAiClient; +import com.oracle.bmc.generativeai.model.ModelCapability; +import com.oracle.bmc.generativeai.requests.ListModelsRequest; +import com.oracle.bmc.generativeai.responses.ListModelsResponse; +import dev.victormartin.oci.genai.backend.backend.dao.GenAiModel; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.beans.factory.annotation.Value; +import org.springframework.stereotype.Service; + +import java.util.List; +import java.util.stream.Collectors; + +@Service +public class GenAIModelsService { + Logger log = LoggerFactory.getLogger(GenAIModelsService.class); + + @Value("${genai.compartment_id}") + private String COMPARTMENT_ID; + + @Autowired + private GenAiClientService generativeAiClientService; + + public List getModels() { + log.info("getModels()"); + ListModelsRequest listModelsRequest = ListModelsRequest.builder() + .compartmentId(COMPARTMENT_ID) + .build(); + GenerativeAiClient client = generativeAiClientService.getClient(); + ListModelsResponse response = client.listModels(listModelsRequest); + return response.getModelCollection().getItems().stream() + .map(m -> { + List capabilities = m.getCapabilities().stream() + .map(ModelCapability::getValue).collect(Collectors.toList()); + GenAiModel model = new GenAiModel( + m.getId(), m.getDisplayName(), m.getVendor(), + m.getVersion(), capabilities, m.getTimeCreated()); + return model; + }).collect(Collectors.toList()); + } +} diff --git a/backend/src/main/java/dev/victormartin/oci/genai/backend/backend/service/OCIGenAIService.java b/backend/src/main/java/dev/victormartin/oci/genai/backend/backend/service/OCIGenAIService.java index a7b1d2a8..ecf9d73c 100644 --- a/backend/src/main/java/dev/victormartin/oci/genai/backend/backend/service/OCIGenAIService.java +++ b/backend/src/main/java/dev/victormartin/oci/genai/backend/backend/service/OCIGenAIService.java @@ -1,48 +1,94 @@ package dev.victormartin.oci.genai.backend.backend.service; -import com.oracle.bmc.generativeaiinference.GenerativeAiInferenceClient; import com.oracle.bmc.generativeaiinference.model.*; +import com.oracle.bmc.generativeaiinference.model.Message; import com.oracle.bmc.generativeaiinference.requests.ChatRequest; -import com.oracle.bmc.generativeaiinference.requests.GenerateTextRequest; -import com.oracle.bmc.generativeaiinference.requests.SummarizeTextRequest; import com.oracle.bmc.generativeaiinference.responses.ChatResponse; -import com.oracle.bmc.generativeaiinference.responses.GenerateTextResponse; -import com.oracle.bmc.generativeaiinference.responses.SummarizeTextResponse; -import com.oracle.bmc.http.client.jersey.WrappedResponseInputStream; -import org.hibernate.boot.archive.scan.internal.StandardScanner; +import dev.victormartin.oci.genai.backend.backend.dao.GenAiModel; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Value; import org.springframework.stereotype.Service; -import java.io.*; -import java.nio.charset.StandardCharsets; +import java.util.ArrayList; import java.util.List; -import java.util.stream.Collectors; @Service public class OCIGenAIService { + + Logger log = LoggerFactory.getLogger(OCIGenAIService.class); + @Value("${genai.compartment_id}") private String COMPARTMENT_ID; @Autowired private GenAiInferenceClientService generativeAiInferenceClientService; - public String resolvePrompt(String input, String modelId, boolean finetune) { - CohereChatRequest cohereChatRequest = CohereChatRequest.builder() - .message(input) - .maxTokens(600) - .temperature((double) 1) - .frequencyPenalty((double) 0) - .topP((double) 0.75) - .topK(0) - .isStream(false) // TODO websockets and streams - .build(); + @Autowired + private GenAIModelsService genAIModelsService; - ChatDetails chatDetails = ChatDetails.builder() - .servingMode(OnDemandServingMode.builder().modelId(modelId).build()) - .compartmentId(COMPARTMENT_ID) - .chatRequest(cohereChatRequest) - .build(); + public String resolvePrompt(String input, String modelId, boolean finetune, boolean summarization) { + + List models = genAIModelsService.getModels(); + GenAiModel currentModel = models.stream() + .filter(m-> modelId.equals(m.id())) + .findFirst() + .orElseThrow(); + + log.info("Model {} with finetune {}", currentModel.name(), finetune? "yes" : "no"); + + double temperature = summarization?0.0:0.5; + + String inputText = summarization?"Summarize this text:\n" + input: input; + + ChatDetails chatDetails; + switch (currentModel.vendor()) { + case "cohere": + CohereChatRequest cohereChatRequest = CohereChatRequest.builder() + .message(inputText) + .maxTokens(600) + .temperature(temperature) + .frequencyPenalty((double) 0) + .topP(0.75) + .topK(0) + .isStream(false) // TODO websockets and streams + .build(); + + chatDetails = ChatDetails.builder() + .servingMode(OnDemandServingMode.builder().modelId(currentModel.id()).build()) + .compartmentId(COMPARTMENT_ID) + .chatRequest(cohereChatRequest) + .build(); + break; + case "meta": + ChatContent content = TextContent.builder() + .text(inputText) + .build(); + List contents = new ArrayList<>(); + contents.add(content); + List messages = new ArrayList<>(); + Message message = new UserMessage(contents, "user"); + messages.add(message); + GenericChatRequest genericChatRequest = GenericChatRequest.builder() + .messages(messages) + .maxTokens(600) + .temperature((double)1) + .frequencyPenalty((double)0) + .presencePenalty((double)0) + .topP(0.75) + .topK(-1) + .isStream(false) + .build(); + chatDetails = ChatDetails.builder() + .servingMode(OnDemandServingMode.builder().modelId(currentModel.id()).build()) + .compartmentId(COMPARTMENT_ID) + .chatRequest(genericChatRequest) + .build(); + break; + default: + throw new IllegalStateException("Unexpected value: " + currentModel.vendor()); + } ChatRequest request = ChatRequest.builder() .chatDetails(chatDetails) @@ -65,7 +111,7 @@ public String resolvePrompt(String input, String modelId, boolean finetune) { } public String summaryText(String input, String modelId, boolean finetuned) { - String response = resolvePrompt("Summarize this:\n" + input, modelId, finetuned); + String response = resolvePrompt(input, modelId, finetuned, true); return response; } }