Skip to content

Commit

Permalink
support meta models
Browse files Browse the repository at this point in the history
  • Loading branch information
vmleon committed Nov 11, 2024
1 parent ef53870 commit aeedf91
Show file tree
Hide file tree
Showing 5 changed files with 127 additions and 40 deletions.
5 changes: 2 additions & 3 deletions app/src/components/content/settings.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -81,9 +81,8 @@ export const Settings = (props: Props) => {
const json = await response.json();
const result = json.filter((model: Model) => {
if (
model.capabilities.includes("TEXT_GENERATION") &&
(model.vendor == "cohere" || model.vendor == "") &&
model.version != "14.2"
model.capabilities.includes("CHAT") &&
(model.vendor == "cohere" || model.vendor == "meta")
)
return model;
});
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import com.oracle.bmc.generativeai.responses.ListEndpointsResponse;
import dev.victormartin.oci.genai.backend.backend.dao.GenAiModel;
import dev.victormartin.oci.genai.backend.backend.dao.GenAiEndpoint;
import dev.victormartin.oci.genai.backend.backend.service.GenAIModelsService;
import dev.victormartin.oci.genai.backend.backend.service.GenAiClientService;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
Expand All @@ -29,19 +30,16 @@ public class GenAIController {
@Autowired
private GenAiClientService generativeAiClientService;

@Autowired
private GenAIModelsService genAIModelsService;

@GetMapping("/api/genai/models")
public List<GenAiModel> getModels() {
logger.info("getModels()");
ListModelsRequest listModelsRequest = ListModelsRequest.builder().compartmentId(COMPARTMENT_ID).build();
GenerativeAiClient client = generativeAiClientService.getClient();
ListModelsResponse response = client.listModels(listModelsRequest);
return response.getModelCollection().getItems().stream().map(m -> {
List<String> capabilities = m.getCapabilities().stream().map(ModelCapability::getValue)
.collect(Collectors.toList());
GenAiModel model = new GenAiModel(m.getId(), m.getDisplayName(), m.getVendor(), m.getVersion(),
capabilities, m.getTimeCreated());
return model;
}).collect(Collectors.toList());
List<GenAiModel> models = genAIModelsService.getModels();
return models.stream()
.filter(m -> m.capabilities().contains("CHAT"))
.collect(Collectors.toList());
}

@GetMapping("/api/genai/endpoints")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ public Answer handlePrompt(Prompt prompt) {
throw new InvalidPromptRequest();
}
saved.setDatetimeResponse(new Date());
String responseFromGenAI = genAI.resolvePrompt(promptEscaped, activeModel, finetune);
String responseFromGenAI = genAI.resolvePrompt(promptEscaped, activeModel, finetune, false);
saved.setResponse(responseFromGenAI);
interactionRepository.save(saved);
return new Answer(responseFromGenAI, "");
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
package dev.victormartin.oci.genai.backend.backend.service;

import com.oracle.bmc.generativeai.GenerativeAiClient;
import com.oracle.bmc.generativeai.model.ModelCapability;
import com.oracle.bmc.generativeai.requests.ListModelsRequest;
import com.oracle.bmc.generativeai.responses.ListModelsResponse;
import dev.victormartin.oci.genai.backend.backend.dao.GenAiModel;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Service;

import java.util.List;
import java.util.stream.Collectors;

@Service
public class GenAIModelsService {
Logger log = LoggerFactory.getLogger(GenAIModelsService.class);

@Value("${genai.compartment_id}")
private String COMPARTMENT_ID;

@Autowired
private GenAiClientService generativeAiClientService;

public List<GenAiModel> getModels() {
log.info("getModels()");
ListModelsRequest listModelsRequest = ListModelsRequest.builder()
.compartmentId(COMPARTMENT_ID)
.build();
GenerativeAiClient client = generativeAiClientService.getClient();
ListModelsResponse response = client.listModels(listModelsRequest);
return response.getModelCollection().getItems().stream()
.map(m -> {
List<String> capabilities = m.getCapabilities().stream()
.map(ModelCapability::getValue).collect(Collectors.toList());
GenAiModel model = new GenAiModel(
m.getId(), m.getDisplayName(), m.getVendor(),
m.getVersion(), capabilities, m.getTimeCreated());
return model;
}).collect(Collectors.toList());
}
}
Original file line number Diff line number Diff line change
@@ -1,48 +1,94 @@
package dev.victormartin.oci.genai.backend.backend.service;

import com.oracle.bmc.generativeaiinference.GenerativeAiInferenceClient;
import com.oracle.bmc.generativeaiinference.model.*;
import com.oracle.bmc.generativeaiinference.model.Message;
import com.oracle.bmc.generativeaiinference.requests.ChatRequest;
import com.oracle.bmc.generativeaiinference.requests.GenerateTextRequest;
import com.oracle.bmc.generativeaiinference.requests.SummarizeTextRequest;
import com.oracle.bmc.generativeaiinference.responses.ChatResponse;
import com.oracle.bmc.generativeaiinference.responses.GenerateTextResponse;
import com.oracle.bmc.generativeaiinference.responses.SummarizeTextResponse;
import com.oracle.bmc.http.client.jersey.WrappedResponseInputStream;
import org.hibernate.boot.archive.scan.internal.StandardScanner;
import dev.victormartin.oci.genai.backend.backend.dao.GenAiModel;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Service;

import java.io.*;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.List;
import java.util.stream.Collectors;

@Service
public class OCIGenAIService {

Logger log = LoggerFactory.getLogger(OCIGenAIService.class);

@Value("${genai.compartment_id}")
private String COMPARTMENT_ID;

@Autowired
private GenAiInferenceClientService generativeAiInferenceClientService;

public String resolvePrompt(String input, String modelId, boolean finetune) {
CohereChatRequest cohereChatRequest = CohereChatRequest.builder()
.message(input)
.maxTokens(600)
.temperature((double) 1)
.frequencyPenalty((double) 0)
.topP((double) 0.75)
.topK(0)
.isStream(false) // TODO websockets and streams
.build();
@Autowired
private GenAIModelsService genAIModelsService;

ChatDetails chatDetails = ChatDetails.builder()
.servingMode(OnDemandServingMode.builder().modelId(modelId).build())
.compartmentId(COMPARTMENT_ID)
.chatRequest(cohereChatRequest)
.build();
public String resolvePrompt(String input, String modelId, boolean finetune, boolean summarization) {

List<GenAiModel> models = genAIModelsService.getModels();
GenAiModel currentModel = models.stream()
.filter(m-> modelId.equals(m.id()))
.findFirst()
.orElseThrow();

log.info("Model {} with finetune {}", currentModel.name(), finetune? "yes" : "no");

double temperature = summarization?0.0:0.5;

String inputText = summarization?"Summarize this text:\n" + input: input;

ChatDetails chatDetails;
switch (currentModel.vendor()) {
case "cohere":
CohereChatRequest cohereChatRequest = CohereChatRequest.builder()
.message(inputText)
.maxTokens(600)
.temperature(temperature)
.frequencyPenalty((double) 0)
.topP(0.75)
.topK(0)
.isStream(false) // TODO websockets and streams
.build();

chatDetails = ChatDetails.builder()
.servingMode(OnDemandServingMode.builder().modelId(currentModel.id()).build())
.compartmentId(COMPARTMENT_ID)
.chatRequest(cohereChatRequest)
.build();
break;
case "meta":
ChatContent content = TextContent.builder()
.text(inputText)
.build();
List<ChatContent> contents = new ArrayList<>();
contents.add(content);
List<Message> messages = new ArrayList<>();
Message message = new UserMessage(contents, "user");
messages.add(message);
GenericChatRequest genericChatRequest = GenericChatRequest.builder()
.messages(messages)
.maxTokens(600)
.temperature((double)1)
.frequencyPenalty((double)0)
.presencePenalty((double)0)
.topP(0.75)
.topK(-1)
.isStream(false)
.build();
chatDetails = ChatDetails.builder()
.servingMode(OnDemandServingMode.builder().modelId(currentModel.id()).build())
.compartmentId(COMPARTMENT_ID)
.chatRequest(genericChatRequest)
.build();
break;
default:
throw new IllegalStateException("Unexpected value: " + currentModel.vendor());
}

ChatRequest request = ChatRequest.builder()
.chatDetails(chatDetails)
Expand All @@ -65,7 +111,7 @@ public String resolvePrompt(String input, String modelId, boolean finetune) {
}

public String summaryText(String input, String modelId, boolean finetuned) {
String response = resolvePrompt("Summarize this:\n" + input, modelId, finetuned);
String response = resolvePrompt(input, modelId, finetuned, true);
return response;
}
}

0 comments on commit aeedf91

Please sign in to comment.