Skip to content

Commit

Permalink
Merge pull request #38 from johnoliver/streaming-3
Browse files Browse the repository at this point in the history
Add streaming capability
  • Loading branch information
dantelmomsft authored Oct 19, 2023
2 parents fc733e6 + 86c5ed5 commit 4296b62
Show file tree
Hide file tree
Showing 18 changed files with 380 additions and 66 deletions.
11 changes: 11 additions & 0 deletions app/backend/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,17 @@
<artifactId>spring-boot-starter-web</artifactId>
</dependency>

<dependency>
<groupId>com.fasterxml.jackson.core</groupId>
<artifactId>jackson-databind</artifactId>
<scope>compile</scope>
</dependency>
<dependency>
<groupId>com.fasterxml.jackson.core</groupId>
<artifactId>jackson-core</artifactId>
<scope>compile</scope>
</dependency>

<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-test</artifactId>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,5 +14,4 @@ public static void main(String[] args) {
LOG.info("Application profile from system property is [{}]", System.getProperty("spring.profiles.active"));
new SpringApplication(Application.class).run(args);
}

}
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
package com.microsoft.openai.samples.rag.approaches;

public interface RAGApproach<I, O> {

O run(I questionOrConversation, RAGOptions options);


import com.microsoft.openai.samples.rag.common.ChatGPTConversation;
import reactor.core.publisher.Flux;

import java.io.OutputStream;

/**
 * Contract for a Retrieval Augmented Generation (RAG) approach implementation.
 *
 * @param <I> the input type — a question {@code String} or a conversation object, depending on the approach
 * @param <O> the output type produced by the synchronous {@link #run} method
 */
public interface RAGApproach<I, O> {

    /**
     * Executes the approach synchronously and returns the complete result.
     *
     * @param questionOrConversation the user question or conversation to answer
     * @param options                retrieval/generation options for this invocation
     * @return the complete generated response
     */
    O run(I questionOrConversation, RAGOptions options);
    /**
     * Executes the approach in streaming mode, writing incremental results to the
     * supplied stream as they become available. The caller owns the stream and is
     * responsible for closing it.
     *
     * @param questionOrConversation the user question or conversation to answer
     * @param options                retrieval/generation options for this invocation
     * @param outputStream           destination for incremental response chunks
     */
    void runStreaming(I questionOrConversation, RAGOptions options, OutputStream outputStream);
}
Original file line number Diff line number Diff line change
@@ -1,35 +1,43 @@
package com.microsoft.openai.samples.rag.ask.approaches;

import com.azure.ai.openai.models.ChatChoice;
import com.azure.ai.openai.models.ChatCompletions;
import com.azure.core.util.IterableStream;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.microsoft.openai.samples.rag.approaches.ContentSource;
import com.microsoft.openai.samples.rag.approaches.RAGApproach;
import com.microsoft.openai.samples.rag.approaches.RAGOptions;
import com.microsoft.openai.samples.rag.approaches.RAGResponse;
import com.microsoft.openai.samples.rag.common.ChatGPTUtils;
import com.microsoft.openai.samples.rag.retrieval.FactsRetrieverProvider;
import com.microsoft.openai.samples.rag.controller.ChatResponse;
import com.microsoft.openai.samples.rag.proxy.OpenAIProxy;
import com.microsoft.openai.samples.rag.retrieval.FactsRetrieverProvider;
import com.microsoft.openai.samples.rag.retrieval.Retriever;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.stereotype.Component;

import java.io.IOException;
import java.io.OutputStream;
import java.util.List;

/**
* Simple retrieve-then-read java implementation, using the Cognitive Search and OpenAI APIs directly. It first retrieves
* top documents from search, then constructs a prompt with them, and then uses OpenAI to generate a completion
* (answer) with that prompt.
* top documents from search, then constructs a prompt with them, and then uses OpenAI to generate a completion
* (answer) with that prompt.
*/
@Component
public class PlainJavaAskApproach implements RAGApproach<String, RAGResponse> {

private static final Logger LOGGER = LoggerFactory.getLogger(PlainJavaAskApproach.class);
private final OpenAIProxy openAIProxy;
private final FactsRetrieverProvider factsRetrieverProvider;
private final ObjectMapper objectMapper;

public PlainJavaAskApproach(FactsRetrieverProvider factsRetrieverProvider, OpenAIProxy openAIProxy) {
public PlainJavaAskApproach(FactsRetrieverProvider factsRetrieverProvider, OpenAIProxy openAIProxy, ObjectMapper objectMapper) {
this.factsRetrieverProvider = factsRetrieverProvider;
this.openAIProxy = openAIProxy;
this.objectMapper = objectMapper;
}

/**
Expand All @@ -39,8 +47,6 @@ public PlainJavaAskApproach(FactsRetrieverProvider factsRetrieverProvider, OpenA
*/
@Override
public RAGResponse run(String question, RAGOptions options) {
//TODO exception handling

//Get instance of retriever based on the retrieval mode: hybryd, text, vectors.
Retriever factsRetriever = factsRetrieverProvider.getFactsRetriever(options);
List<ContentSource> sources = factsRetriever.retrieveFromQuestion(question, options);
Expand All @@ -51,14 +57,14 @@ public RAGResponse run(String question, RAGOptions options) {
var customPromptEmpty = (customPrompt == null) || (customPrompt != null && customPrompt.isEmpty());

//true will replace the default prompt. False will add custom prompt as suffix to the default prompt
var replacePrompt = !customPromptEmpty && !customPrompt.startsWith("|");
if(!replacePrompt && !customPromptEmpty){
var replacePrompt = !customPromptEmpty && !customPrompt.startsWith("|");
if (!replacePrompt && !customPromptEmpty) {
customPrompt = customPrompt.substring(1);
}

var answerQuestionChatTemplate = new AnswerQuestionChatTemplate(customPrompt, replacePrompt);

var groundedChatMessages = answerQuestionChatTemplate.getMessages(question,sources);
var groundedChatMessages = answerQuestionChatTemplate.getMessages(question, sources);
var chatCompletionsOptions = ChatGPTUtils.buildDefaultChatCompletionsOptions(groundedChatMessages);

// STEP 3: Generate a contextual and content specific answer using the retrieve facts
Expand All @@ -75,8 +81,67 @@ public RAGResponse run(String question, RAGOptions options) {
.answer(chatCompletions.getChoices().get(0).getMessage().getContent())
.sources(sources)
.build();

}

@Override
public void runStreaming(String question, RAGOptions options, OutputStream outputStream) {
    // Get instance of retriever based on the retrieval mode: hybrid, text, vectors.
    Retriever factsRetriever = factsRetrieverProvider.getFactsRetriever(options);
    List<ContentSource> sources = factsRetriever.retrieveFromQuestion(question, options);
    LOGGER.info("Total {} sources found in cognitive search for keyword search query[{}]", sources.size(),
            question);

    var customPrompt = options.getPromptTemplate();
    var customPromptEmpty = (customPrompt == null) || customPrompt.isEmpty();

    // true will replace the default prompt. False will add custom prompt as suffix to the default prompt
    var replacePrompt = !customPromptEmpty && !customPrompt.startsWith("|");
    if (!replacePrompt && !customPromptEmpty) {
        // Strip the leading "|" marker that flagged the prompt as a suffix.
        customPrompt = customPrompt.substring(1);
    }

    var answerQuestionChatTemplate = new AnswerQuestionChatTemplate(customPrompt, replacePrompt);

    var groundedChatMessages = answerQuestionChatTemplate.getMessages(question, sources);
    var chatCompletionsOptions = ChatGPTUtils.buildDefaultChatCompletionsOptions(groundedChatMessages);

    // Generate a contextual answer using the retrieved facts, consuming the completion as a stream.
    IterableStream<ChatCompletions> completions = openAIProxy.getChatCompletionsStream(chatCompletionsOptions);
    int index = 0;
    for (ChatCompletions completion : completions) {

        // NOTE(review): usage stats are typically absent (null) on intermediate streaming
        // chunks — guard before logging to avoid a NullPointerException mid-stream.
        if (completion.getUsage() != null) {
            LOGGER.info("Chat completion generated with Prompt Tokens[{}], Completions Tokens[{}], Total Tokens[{}]",
                    completion.getUsage().getPromptTokens(),
                    completion.getUsage().getCompletionTokens(),
                    completion.getUsage().getTotalTokens());
        }

        for (ChatChoice choice : completion.getChoices()) {
            // In streaming mode the incremental content arrives on the delta, not on the
            // full message (which is null for stream chunks).
            String deltaContent = choice.getDelta().getContent();
            if (deltaContent == null) {
                continue;
            }

            // BUG FIX: the answer was previously read from choice.getMessage().getContent(),
            // which is null for streaming responses; use the delta content instead.
            RAGResponse ragResponse = new RAGResponse.Builder()
                    .question(question)
                    .prompt(ChatGPTUtils.formatAsChatML(groundedChatMessages))
                    .answer(deltaContent)
                    .sources(sources)
                    .build();

            // The first chunk carries the full response envelope; subsequent chunks are deltas.
            ChatResponse response;
            if (index == 0) {
                response = ChatResponse.buildChatResponse(ragResponse);
            } else {
                response = ChatResponse.buildChatDeltaResponse(index, ragResponse);
            }
            index++;

            try {
                // Newline-delimited JSON: one serialized ChatResponse per line.
                String value = objectMapper.writeValueAsString(response) + "\n";
                // Encode explicitly as UTF-8 rather than relying on the platform default charset.
                outputStream.write(value.getBytes(java.nio.charset.StandardCharsets.UTF_8));
                outputStream.flush();
            } catch (IOException e) {
                throw new RuntimeException(e);
            }
        }
    }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Component;
import reactor.core.publisher.Flux;

import java.io.OutputStream;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
Expand Down Expand Up @@ -81,6 +83,11 @@ public RAGResponse run(String question, RAGOptions options) {

}

@Override
public void runStreaming(String questionOrConversation, RAGOptions options, OutputStream outputStream) {
    // Streaming is intentionally unimplemented for this approach; fail fast so callers
    // do not block on a stream that will never produce output.
    // UnsupportedOperationException is the standard JDK type for an unsupported operation
    // (IllegalStateException implies a fixable object state, which is not the case here).
    throw new UnsupportedOperationException("Streaming not supported for this approach");
}

private List<ContentSource> formSourcesList(String result) {
if (result == null) {
return Collections.emptyList();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Component;
import reactor.core.publisher.Flux;

import java.io.OutputStream;
import java.util.Objects;
import java.util.Set;

Expand Down Expand Up @@ -80,6 +82,11 @@ public RAGResponse run(String question, RAGOptions options) {

}

@Override
public void runStreaming(String questionOrConversation, RAGOptions options, OutputStream outputStream) {
    // Streaming is intentionally unimplemented for this approach; fail fast so callers
    // do not block on a stream that will never produce output.
    // UnsupportedOperationException is the standard JDK type for an unsupported operation
    // (IllegalStateException implies a fixable object state, which is not the case here).
    throw new UnsupportedOperationException("Streaming not supported for this approach");
}

private Kernel buildSemanticKernel( RAGOptions options) {
Kernel kernel = SKBuilders.kernel()
.withDefaultAIService(SKBuilders.chatCompletion()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,16 +19,18 @@
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Component;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;

import java.io.OutputStream;
import java.util.List;
import java.util.function.Function;
import java.util.stream.Collectors;

/**
* Accomplish the same task as in the PlainJavaAskApproach approach but using Semantic Kernel framework:
* 1. Memory abstraction is used for vector search capability. It uses Azure Cognitive Search as memory store.
* 2. Semantic function has been defined to ask question using sources from memory search results
* Accomplish the same task as in the PlainJavaAskApproach approach but using Semantic Kernel framework:
* 1. Memory abstraction is used for vector search capability. It uses Azure Cognitive Search as memory store.
* 2. Semantic function has been defined to ask question using sources from memory search results
*/
@Component
public class JavaSemanticKernelWithMemoryApproach implements RAGApproach<String, RAGResponse> {
Expand All @@ -40,8 +42,10 @@ public class JavaSemanticKernelWithMemoryApproach implements RAGApproach<String,

private final String EMBEDDING_FIELD_NAME = "embedding";

@Value("${cognitive.search.service}") String searchServiceName ;
@Value("${cognitive.search.index}") String indexName;
@Value("${cognitive.search.service}")
String searchServiceName;
@Value("${cognitive.search.index}")
String indexName;
@Value("${openai.chatgpt.deployment}")
private String gptChatDeploymentModelId;

Expand Down Expand Up @@ -70,11 +74,11 @@ public RAGResponse run(String question, RAGOptions options) {
* Question embeddings are provided to cognitive search via search options.
*/
List<MemoryQueryResult> memoryResult = semanticKernel.getMemory().searchAsync(
indexName,
question,
options.getTop(),
0.5f,
false)
indexName,
question,
options.getTop(),
0.5f,
false)
.block();

LOGGER.info("Total {} sources found in cognitive vector store for search query[{}]", memoryResult.size(), question);
Expand All @@ -90,14 +94,19 @@ public RAGResponse run(String question, RAGOptions options) {
Mono<SKContext> result = semanticKernel.getFunction("RAG", "AnswerQuestion").invokeAsync(skcontext);

return new RAGResponse.Builder()
//.prompt(plan.toPlanString())
.prompt("placeholders for prompt")
.answer(result.block().getResult())
.sources(sourcesList)
.sourcesAsText(sources)
.question(question)
.build();
//.prompt(plan.toPlanString())
.prompt("placeholders for prompt")
.answer(result.block().getResult())
.sources(sourcesList)
.sourcesAsText(sources)
.question(question)
.build();

}

@Override
public void runStreaming(String questionOrConversation, RAGOptions options, OutputStream outputStream) {
    // Streaming is intentionally unimplemented for this approach; fail fast so callers
    // do not block on a stream that will never produce output.
    // UnsupportedOperationException is the standard JDK type for an unsupported operation
    // (IllegalStateException implies a fixable object state, which is not the case here).
    throw new UnsupportedOperationException("Streaming not supported for this approach");
}

private List<ContentSource> buildSources(List<MemoryQueryResult> memoryResult) {
Expand All @@ -123,15 +132,14 @@ private String buildSourcesText(List<MemoryQueryResult> memoryResult) {
return sourcesContentBuffer.toString();
}

private Kernel buildSemanticKernel( RAGOptions options) {

private Kernel buildSemanticKernel(RAGOptions options) {
var kernelWithACS = SKBuilders.kernel()
.withMemoryStorage(
new CustomAzureCognitiveSearchMemoryStore("https://%s.search.windows.net".formatted(searchServiceName),
tokenCredential,
this.searchAsyncClient,
this.EMBEDDING_FIELD_NAME,
buildCustomMemoryMapper()))
tokenCredential,
this.searchAsyncClient,
this.EMBEDDING_FIELD_NAME,
buildCustomMemoryMapper()))
.withDefaultAIService(SKBuilders.textEmbeddingGeneration()
.withOpenAIClient(openAIAsyncClient)
.withModelId(embeddingDeploymentModelId)
Expand All @@ -142,14 +150,13 @@ private Kernel buildSemanticKernel( RAGOptions options) {
.build())
.build();

kernelWithACS.importSkillFromResources("semantickernel/Plugins","RAG","AnswerQuestion",null);
return kernelWithACS;
kernelWithACS.importSkillFromResources("semantickernel/Plugins", "RAG", "AnswerQuestion", null);
return kernelWithACS;
}


private Function<SearchDocument, MemoryRecord> buildCustomMemoryMapper(){
private Function<SearchDocument, MemoryRecord> buildCustomMemoryMapper() {
return searchDocument -> {
return MemoryRecord.localRecord(
return MemoryRecord.localRecord(
(String) searchDocument.get("sourcepage"),
(String) searchDocument.get("content"),
"chunked text from original source",
Expand Down
Loading

0 comments on commit 4296b62

Please sign in to comment.