From ef534b5a65dac2d54a797e446a291c210055baf4 Mon Sep 17 00:00:00 2001 From: John Oliver <1615532+johnoliver@users.noreply.github.com> Date: Fri, 20 Oct 2023 12:11:10 +0100 Subject: [PATCH] Apply formatting rules --- .../openai/samples/rag/Application.java | 5 +- .../samples/rag/approaches/ContentSource.java | 90 ++--- .../rag/approaches/PromptTemplate.java | 18 +- .../samples/rag/approaches/RAGApproach.java | 23 +- .../rag/approaches/RAGApproachFactory.java | 2 +- .../RAGApproachFactorySpringBootImpl.java | 27 +- .../samples/rag/approaches/RAGOptions.java | 226 +++++------ .../samples/rag/approaches/RAGResponse.java | 157 ++++---- .../samples/rag/approaches/RAGType.java | 13 +- .../samples/rag/approaches/RetrievalMode.java | 14 +- .../rag/approaches/SemanticKernelMode.java | 13 +- .../AnswerQuestionChatTemplate.java | 221 ++++++----- .../ask/approaches/PlainJavaAskApproach.java | 327 ++++++++-------- .../semantickernel/CognitiveSearchPlugin.java | 118 +++--- .../JavaSemanticKernelChainsApproach.java | 316 +++++++-------- .../JavaSemanticKernelPlannerApproach.java | 248 ++++++------ .../JavaSemanticKernelWithMemoryApproach.java | 368 ++++++++++-------- ...CustomAzureCognitiveSearchMemoryStore.java | 180 +++++---- .../rag/ask/controller/AskController.java | 119 +++--- .../approaches/PlainJavaChatApproach.java | 313 ++++++++------- .../chat/approaches/SemanticSearchChat.java | 252 ++++++------ .../rag/chat/controller/ChatController.java | 132 ++++--- .../rag/common/ChatGPTConversation.java | 72 ++-- .../samples/rag/common/ChatGPTMessage.java | 50 ++- .../samples/rag/common/ChatGPTUtils.java | 102 ++--- .../AzureAuthenticationConfiguration.java | 50 +-- .../config/CognitiveSearchConfiguration.java | 168 ++++---- .../rag/config/OpenAIConfiguration.java | 157 ++++---- .../content/controller/ContentController.java | 17 +- .../rag/controller/ChatAppRequest.java | 5 +- .../rag/controller/ChatAppRequestContext.java | 4 +- .../controller/ChatAppRequestOverrides.java | 5 +- .../samples/rag/controller/ChatResponse.java | 115 +++--- .../rag/controller/ResponseChoice.java | 5 +- .../rag/controller/ResponseContext.java | 4 +- .../rag/controller/ResponseMessage.java | 4 +- .../rag/controller/auth/AuthSetup.java | 4 +- .../samples/rag/proxy/BlobStorageProxy.java | 95 +++-- .../rag/proxy/CognitiveSearchProxy.java | 58 ++- .../openai/samples/rag/proxy/OpenAIProxy.java | 193 ++++----- .../retrieval/CognitiveSearchRetriever.java | 352 +++++++++-------- .../ExtractKeywordsChatTemplate.java | 158 ++++---- .../rag/retrieval/FactsRetrieverProvider.java | 65 ++-- .../samples/rag/retrieval/Retriever.java | 29 +- .../openai/samples/rag/AskAPITest.java | 115 +++--- .../openai/samples/rag/ChatAPITest.java | 58 ++- .../RAGApproachFactorySpringBootImplTest.java | 47 ++- .../test/config/ProxyMockConfiguration.java | 82 ++-- .../utils/CognitiveSearchUnitTestUtils.java | 132 ++++--- .../rag/test/utils/OpenAIUnitTestUtils.java | 17 +- 50 files changed, 2793 insertions(+), 2552 deletions(-) diff --git a/app/backend/src/main/java/com/microsoft/openai/samples/rag/Application.java b/app/backend/src/main/java/com/microsoft/openai/samples/rag/Application.java index 775c6dd..50fefcd 100644 --- a/app/backend/src/main/java/com/microsoft/openai/samples/rag/Application.java +++ b/app/backend/src/main/java/com/microsoft/openai/samples/rag/Application.java @@ -1,3 +1,4 @@ +// Copyright (c) Microsoft. All rights reserved. package com.microsoft.openai.samples.rag; import org.slf4j.Logger; @@ -11,7 +12,9 @@ public class Application { private static final Logger LOG = LoggerFactory.getLogger(Application.class); public static void main(String[] args) { - LOG.info("Application profile from system property is [{}]", System.getProperty("spring.profiles.active")); + LOG.info( + "Application profile from system property is [{}]", + System.getProperty("spring.profiles.active")); new SpringApplication(Application.class).run(args); } } diff --git a/app/backend/src/main/java/com/microsoft/openai/samples/rag/approaches/ContentSource.java b/app/backend/src/main/java/com/microsoft/openai/samples/rag/approaches/ContentSource.java index e173a30..2c10480 100644 --- a/app/backend/src/main/java/com/microsoft/openai/samples/rag/approaches/ContentSource.java +++ b/app/backend/src/main/java/com/microsoft/openai/samples/rag/approaches/ContentSource.java @@ -1,45 +1,45 @@ -package com.microsoft.openai.samples.rag.approaches; - -public class ContentSource { - - private String sourceName; - private String sourceContent; - private final boolean noNewLine; - - public ContentSource(String sourceName, String sourceContent, Boolean noNewLine) { - this.noNewLine = noNewLine; - this.sourceName = sourceName; - buildContent(sourceContent); - } - - public ContentSource(String sourceName, String sourceContent) { - this(sourceName, sourceContent, true); - } - - public String getSourceName() { - return sourceName; - } - - public void setSourceName(String sourceName) { - this.sourceName = sourceName; - } - - public String getSourceContent() { - return sourceContent; - } - - public void setSourceContent(String sourceContent) { - this.sourceContent = sourceContent; - } - - public boolean isNoNewLine() { - return noNewLine; - } - - private void buildContent(String sourceContent) { - if (this.noNewLine) { - this.sourceContent = sourceContent.replace("\n", ""); - } - } - -} +// Copyright (c) Microsoft. All rights reserved. +package com.microsoft.openai.samples.rag.approaches; + +public class ContentSource { + + private String sourceName; + private String sourceContent; + private final boolean noNewLine; + + public ContentSource(String sourceName, String sourceContent, Boolean noNewLine) { + this.noNewLine = noNewLine; + this.sourceName = sourceName; + buildContent(sourceContent); + } + + public ContentSource(String sourceName, String sourceContent) { + this(sourceName, sourceContent, true); + } + + public String getSourceName() { + return sourceName; + } + + public void setSourceName(String sourceName) { + this.sourceName = sourceName; + } + + public String getSourceContent() { + return sourceContent; + } + + public void setSourceContent(String sourceContent) { + this.sourceContent = sourceContent; + } + + public boolean isNoNewLine() { + return noNewLine; + } + + private void buildContent(String sourceContent) { + if (this.noNewLine) { + this.sourceContent = sourceContent.replace("\n", ""); + } + } +} diff --git a/app/backend/src/main/java/com/microsoft/openai/samples/rag/approaches/PromptTemplate.java b/app/backend/src/main/java/com/microsoft/openai/samples/rag/approaches/PromptTemplate.java index 5056dae..c4b2a88 100644 --- a/app/backend/src/main/java/com/microsoft/openai/samples/rag/approaches/PromptTemplate.java +++ b/app/backend/src/main/java/com/microsoft/openai/samples/rag/approaches/PromptTemplate.java @@ -1,9 +1,9 @@ -package com.microsoft.openai.samples.rag.approaches; - -public interface PromptTemplate { - - String getPrompt(); - - void setVariables(); - -} +// Copyright (c) Microsoft. All rights reserved. +package com.microsoft.openai.samples.rag.approaches; + +public interface PromptTemplate { + + String getPrompt(); + + void setVariables(); +} diff --git a/app/backend/src/main/java/com/microsoft/openai/samples/rag/approaches/RAGApproach.java b/app/backend/src/main/java/com/microsoft/openai/samples/rag/approaches/RAGApproach.java index 2e3210b..46edf0b 100644 --- a/app/backend/src/main/java/com/microsoft/openai/samples/rag/approaches/RAGApproach.java +++ b/app/backend/src/main/java/com/microsoft/openai/samples/rag/approaches/RAGApproach.java @@ -1,12 +1,11 @@ -package com.microsoft.openai.samples.rag.approaches; - -import com.microsoft.openai.samples.rag.common.ChatGPTConversation; -import reactor.core.publisher.Flux; - -import java.io.OutputStream; - -public interface RAGApproach { - - O run(I questionOrConversation, RAGOptions options); - void runStreaming(I questionOrConversation, RAGOptions options, OutputStream outputStream); -} +// Copyright (c) Microsoft. All rights reserved. +package com.microsoft.openai.samples.rag.approaches; + +import java.io.OutputStream; + +public interface RAGApproach { + + O run(I questionOrConversation, RAGOptions options); + + void runStreaming(I questionOrConversation, RAGOptions options, OutputStream outputStream); +} diff --git a/app/backend/src/main/java/com/microsoft/openai/samples/rag/approaches/RAGApproachFactory.java b/app/backend/src/main/java/com/microsoft/openai/samples/rag/approaches/RAGApproachFactory.java index c43fdc8..fdabbf6 100644 --- a/app/backend/src/main/java/com/microsoft/openai/samples/rag/approaches/RAGApproachFactory.java +++ b/app/backend/src/main/java/com/microsoft/openai/samples/rag/approaches/RAGApproachFactory.java @@ -1,7 +1,7 @@ +// Copyright (c) Microsoft. All rights reserved. package com.microsoft.openai.samples.rag.approaches; public interface RAGApproachFactory { RAGApproach createApproach(String approachName, RAGType ragType, RAGOptions options); - } diff --git a/app/backend/src/main/java/com/microsoft/openai/samples/rag/approaches/RAGApproachFactorySpringBootImpl.java b/app/backend/src/main/java/com/microsoft/openai/samples/rag/approaches/RAGApproachFactorySpringBootImpl.java index 5e9dd14..a94b01b 100644 --- a/app/backend/src/main/java/com/microsoft/openai/samples/rag/approaches/RAGApproachFactorySpringBootImpl.java +++ b/app/backend/src/main/java/com/microsoft/openai/samples/rag/approaches/RAGApproachFactorySpringBootImpl.java @@ -1,16 +1,18 @@ +// Copyright (c) Microsoft. All rights reserved. package com.microsoft.openai.samples.rag.approaches; import com.microsoft.openai.samples.rag.ask.approaches.PlainJavaAskApproach; import com.microsoft.openai.samples.rag.ask.approaches.semantickernel.JavaSemanticKernelChainsApproach; -import com.microsoft.openai.samples.rag.ask.approaches.semantickernel.JavaSemanticKernelWithMemoryApproach; import com.microsoft.openai.samples.rag.ask.approaches.semantickernel.JavaSemanticKernelPlannerApproach; +import com.microsoft.openai.samples.rag.ask.approaches.semantickernel.JavaSemanticKernelWithMemoryApproach; import com.microsoft.openai.samples.rag.chat.approaches.PlainJavaChatApproach; import org.springframework.context.ApplicationContext; import org.springframework.context.ApplicationContextAware; import org.springframework.stereotype.Component; @Component -public class RAGApproachFactorySpringBootImpl implements RAGApproachFactory, ApplicationContextAware { +public class RAGApproachFactorySpringBootImpl + implements RAGApproachFactory, ApplicationContextAware { private static final String JAVA_OPENAI_SDK = "jos"; private static final String JAVA_SEMANTIC_KERNEL = "jsk"; @@ -35,18 +37,23 @@ public RAGApproach createApproach(String approachName, RAGType ragType, RAGOptio return applicationContext.getBean(PlainJavaAskApproach.class); else if (JAVA_SEMANTIC_KERNEL.equals(approachName)) return applicationContext.getBean(JavaSemanticKernelWithMemoryApproach.class); - else if (JAVA_SEMANTIC_KERNEL_PLANNER.equals(approachName) && ragOptions.getSemantickKernelMode() != null && ragOptions.getSemantickKernelMode() == SemanticKernelMode.planner) - return applicationContext.getBean(JavaSemanticKernelPlannerApproach.class); - else if(JAVA_SEMANTIC_KERNEL_PLANNER.equals(approachName) && ragOptions != null && ragOptions.getSemantickKernelMode() != null && ragOptions.getSemantickKernelMode() == SemanticKernelMode.chains) - return applicationContext.getBean(JavaSemanticKernelChainsApproach.class); - + else if (JAVA_SEMANTIC_KERNEL_PLANNER.equals(approachName) + && ragOptions.getSemantickKernelMode() != null + && ragOptions.getSemantickKernelMode() == SemanticKernelMode.planner) + return applicationContext.getBean(JavaSemanticKernelPlannerApproach.class); + else if (JAVA_SEMANTIC_KERNEL_PLANNER.equals(approachName) + && ragOptions != null + && ragOptions.getSemantickKernelMode() != null + && ragOptions.getSemantickKernelMode() == SemanticKernelMode.chains) + return applicationContext.getBean(JavaSemanticKernelChainsApproach.class); } - //if this point is reached then the combination of approach and rag type is not supported - throw new IllegalArgumentException("Invalid combination for approach[%s] and rag type[%s]: ".formatted(approachName, ragType)); + // if this point is reached then the combination of approach and rag type is not supported + throw new IllegalArgumentException( + "Invalid combination for approach[%s] and rag type[%s]: " + .formatted(approachName, ragType)); } public void setApplicationContext(ApplicationContext applicationContext) { this.applicationContext = applicationContext; } - } diff --git a/app/backend/src/main/java/com/microsoft/openai/samples/rag/approaches/RAGOptions.java b/app/backend/src/main/java/com/microsoft/openai/samples/rag/approaches/RAGOptions.java index b04ed7d..7eae502 100644 --- a/app/backend/src/main/java/com/microsoft/openai/samples/rag/approaches/RAGOptions.java +++ b/app/backend/src/main/java/com/microsoft/openai/samples/rag/approaches/RAGOptions.java @@ -1,112 +1,114 @@ -package com.microsoft.openai.samples.rag.approaches; - -public class RAGOptions { - - private RetrievalMode retrievalMode; - private SemanticKernelMode semantickKernelMode; - private boolean semanticRanker; - private boolean semanticCaptions; - private boolean suggestFollowupQuestions; - private String excludeCategory; - private String promptTemplate; - private Integer top; - - private RAGOptions() { - } - - public RetrievalMode getRetrievalMode() { - return retrievalMode; - } - public SemanticKernelMode getSemantickKernelMode() { - return semantickKernelMode; - } - public boolean isSemanticRanker() { - return semanticRanker; - } - - public boolean isSemanticCaptions() { - return semanticCaptions; - } - - public String getExcludeCategory() { - return excludeCategory; - } - - public String getPromptTemplate() { - return promptTemplate; - } - - public Integer getTop() { - return top; - } - - public boolean isSuggestFollowupQuestions() { - return suggestFollowupQuestions; - } - - public static class Builder { - private RetrievalMode retrievalMode; - - private SemanticKernelMode semanticKernelMode; - private boolean semanticRanker; - private boolean semanticCaptions; - private String excludeCategory; - private String promptTemplate; - private Integer top; - - private boolean suggestFollowupQuestions; - - public Builder retrievialMode(String retrievialMode) { - this.retrievalMode = RetrievalMode.valueOf(retrievialMode); - return this; - } - public Builder semanticKernelMode(String semanticKernelMode) { - this.semanticKernelMode = SemanticKernelMode.valueOf(semanticKernelMode); - return this; - } - public Builder semanticRanker(boolean semanticRanker) { - this.semanticRanker = semanticRanker; - return this; - } - - public Builder semanticCaptions(boolean semanticCaptions) { - this.semanticCaptions = semanticCaptions; - return this; - } - - public Builder suggestFollowupQuestions(boolean suggestFollowupQuestions) { - this.suggestFollowupQuestions = suggestFollowupQuestions; - return this; - } - - public Builder excludeCategory(String excludeCategory) { - this.excludeCategory = excludeCategory; - return this; - } - - public Builder promptTemplate(String promptTemplate) { - this.promptTemplate = promptTemplate; - return this; - } - - public Builder top(Integer top) { - this.top = top; - return this; - } - - - public RAGOptions build() { - RAGOptions ragOptions = new RAGOptions(); - ragOptions.retrievalMode = this.retrievalMode; - ragOptions.semantickKernelMode = this.semanticKernelMode; - ragOptions.semanticRanker = this.semanticRanker; - ragOptions.semanticCaptions = this.semanticCaptions; - ragOptions.suggestFollowupQuestions = this.suggestFollowupQuestions; - ragOptions.excludeCategory = this.excludeCategory; - ragOptions.promptTemplate = this.promptTemplate; - ragOptions.top = this.top; - return ragOptions; - } - } - -} \ No newline at end of file +// Copyright (c) Microsoft. All rights reserved. +package com.microsoft.openai.samples.rag.approaches; + +public class RAGOptions { + + private RetrievalMode retrievalMode; + private SemanticKernelMode semantickKernelMode; + private boolean semanticRanker; + private boolean semanticCaptions; + private boolean suggestFollowupQuestions; + private String excludeCategory; + private String promptTemplate; + private Integer top; + + private RAGOptions() {} + + public RetrievalMode getRetrievalMode() { + return retrievalMode; + } + + public SemanticKernelMode getSemantickKernelMode() { + return semantickKernelMode; + } + + public boolean isSemanticRanker() { + return semanticRanker; + } + + public boolean isSemanticCaptions() { + return semanticCaptions; + } + + public String getExcludeCategory() { + return excludeCategory; + } + + public String getPromptTemplate() { + return promptTemplate; + } + + public Integer getTop() { + return top; + } + + public boolean isSuggestFollowupQuestions() { + return suggestFollowupQuestions; + } + + public static class Builder { + private RetrievalMode retrievalMode; + + private SemanticKernelMode semanticKernelMode; + private boolean semanticRanker; + private boolean semanticCaptions; + private String excludeCategory; + private String promptTemplate; + private Integer top; + + private boolean suggestFollowupQuestions; + + public Builder retrievialMode(String retrievialMode) { + this.retrievalMode = RetrievalMode.valueOf(retrievialMode); + return this; + } + + public Builder semanticKernelMode(String semanticKernelMode) { + this.semanticKernelMode = SemanticKernelMode.valueOf(semanticKernelMode); + return this; + } + + public Builder semanticRanker(boolean semanticRanker) { + this.semanticRanker = semanticRanker; + return this; + } + + public Builder semanticCaptions(boolean semanticCaptions) { + this.semanticCaptions = semanticCaptions; + return this; + } + + public Builder suggestFollowupQuestions(boolean suggestFollowupQuestions) { + this.suggestFollowupQuestions = suggestFollowupQuestions; + return this; + } + + public Builder excludeCategory(String excludeCategory) { + this.excludeCategory = excludeCategory; + return this; + } + + public Builder promptTemplate(String promptTemplate) { + this.promptTemplate = promptTemplate; + return this; + } + + public Builder top(Integer top) { + this.top = top; + return this; + } + + public RAGOptions build() { + RAGOptions ragOptions = new RAGOptions(); + ragOptions.retrievalMode = this.retrievalMode; + ragOptions.semantickKernelMode = this.semanticKernelMode; + ragOptions.semanticRanker = this.semanticRanker; + ragOptions.semanticCaptions = this.semanticCaptions; + ragOptions.suggestFollowupQuestions = this.suggestFollowupQuestions; + ragOptions.excludeCategory = this.excludeCategory; + ragOptions.promptTemplate = this.promptTemplate; + ragOptions.top = this.top; + return ragOptions; + } + } +} diff --git a/app/backend/src/main/java/com/microsoft/openai/samples/rag/approaches/RAGResponse.java b/app/backend/src/main/java/com/microsoft/openai/samples/rag/approaches/RAGResponse.java index 614b270..f2c6e47 100644 --- a/app/backend/src/main/java/com/microsoft/openai/samples/rag/approaches/RAGResponse.java +++ b/app/backend/src/main/java/com/microsoft/openai/samples/rag/approaches/RAGResponse.java @@ -1,79 +1,78 @@ -package com.microsoft.openai.samples.rag.approaches; - -import java.util.List; - -public class RAGResponse { - - private final String question; - private final List sources; - private final String sourcesAsText; - private final String answer; - private final String prompt; - - private RAGResponse(Builder builder) { - this.question = builder.question; - this.sources = builder.sources; - this.answer = builder.answer; - this.prompt = builder.prompt; - this.sourcesAsText = builder.sourcesAsText; - } - - public String getQuestion() { - return question; - } - - public List getSources() { - return sources; - } - - public String getSourcesAsText() { - return sourcesAsText; - } - - public String getAnswer() { - return answer; - } - - public String getPrompt() { - return prompt; - } - - public static class Builder { - private String question; - private List sources; - private String sourcesAsText; - private String answer; - private String prompt; - - - public Builder question(String question) { - this.question = question; - return this; - } - - public Builder sources(List sources) { - this.sources = sources; - return this; - } - - public Builder sourcesAsText(String sourcesAsText) { - this.sourcesAsText = sourcesAsText; - return this; - } - - public Builder answer(String answer) { - this.answer = answer; - return this; - } - - public Builder prompt(String prompt) { - this.prompt = prompt; - return this; - } - - public RAGResponse build() { - return new RAGResponse(this); - } - } - -} +// Copyright (c) Microsoft. All rights reserved. +package com.microsoft.openai.samples.rag.approaches; + +import java.util.List; + +public class RAGResponse { + + private final String question; + private final List sources; + private final String sourcesAsText; + private final String answer; + private final String prompt; + + private RAGResponse(Builder builder) { + this.question = builder.question; + this.sources = builder.sources; + this.answer = builder.answer; + this.prompt = builder.prompt; + this.sourcesAsText = builder.sourcesAsText; + } + + public String getQuestion() { + return question; + } + + public List getSources() { + return sources; + } + + public String getSourcesAsText() { + return sourcesAsText; + } + + public String getAnswer() { + return answer; + } + + public String getPrompt() { + return prompt; + } + + public static class Builder { + private String question; + private List sources; + private String sourcesAsText; + private String answer; + private String prompt; + + public Builder question(String question) { + this.question = question; + return this; + } + + public Builder sources(List sources) { + this.sources = sources; + return this; + } + + public Builder sourcesAsText(String sourcesAsText) { + this.sourcesAsText = sourcesAsText; + return this; + } + + public Builder answer(String answer) { + this.answer = answer; + return this; + } + + public Builder prompt(String prompt) { + this.prompt = prompt; + return this; + } + + public RAGResponse build() { + return new RAGResponse(this); + } + } +} diff --git a/app/backend/src/main/java/com/microsoft/openai/samples/rag/approaches/RAGType.java b/app/backend/src/main/java/com/microsoft/openai/samples/rag/approaches/RAGType.java index c56eda3..ac74168 100644 --- a/app/backend/src/main/java/com/microsoft/openai/samples/rag/approaches/RAGType.java +++ b/app/backend/src/main/java/com/microsoft/openai/samples/rag/approaches/RAGType.java @@ -1,6 +1,7 @@ -package com.microsoft.openai.samples.rag.approaches; - -public enum RAGType { - CHAT, ASK - -} +// Copyright (c) Microsoft. All rights reserved. +package com.microsoft.openai.samples.rag.approaches; + +public enum RAGType { + CHAT, + ASK +} diff --git a/app/backend/src/main/java/com/microsoft/openai/samples/rag/approaches/RetrievalMode.java b/app/backend/src/main/java/com/microsoft/openai/samples/rag/approaches/RetrievalMode.java index 82f20b4..66df143 100644 --- a/app/backend/src/main/java/com/microsoft/openai/samples/rag/approaches/RetrievalMode.java +++ b/app/backend/src/main/java/com/microsoft/openai/samples/rag/approaches/RetrievalMode.java @@ -1,6 +1,8 @@ -package com.microsoft.openai.samples.rag.approaches; - -public enum RetrievalMode { - hybrid, vectors, text; - -} +// Copyright (c) Microsoft. All rights reserved. +package com.microsoft.openai.samples.rag.approaches; + +public enum RetrievalMode { + hybrid, + vectors, + text; +} diff --git a/app/backend/src/main/java/com/microsoft/openai/samples/rag/approaches/SemanticKernelMode.java b/app/backend/src/main/java/com/microsoft/openai/samples/rag/approaches/SemanticKernelMode.java index c1c22e7..0192d3c 100644 --- a/app/backend/src/main/java/com/microsoft/openai/samples/rag/approaches/SemanticKernelMode.java +++ b/app/backend/src/main/java/com/microsoft/openai/samples/rag/approaches/SemanticKernelMode.java @@ -1,6 +1,7 @@ -package com.microsoft.openai.samples.rag.approaches; - -public enum SemanticKernelMode { - chains, planner; - -} +// Copyright (c) Microsoft. All rights reserved. +package com.microsoft.openai.samples.rag.approaches; + +public enum SemanticKernelMode { + chains, + planner; +} diff --git a/app/backend/src/main/java/com/microsoft/openai/samples/rag/ask/approaches/AnswerQuestionChatTemplate.java b/app/backend/src/main/java/com/microsoft/openai/samples/rag/ask/approaches/AnswerQuestionChatTemplate.java index eac6830..3e73bf7 100644 --- a/app/backend/src/main/java/com/microsoft/openai/samples/rag/ask/approaches/AnswerQuestionChatTemplate.java +++ b/app/backend/src/main/java/com/microsoft/openai/samples/rag/ask/approaches/AnswerQuestionChatTemplate.java @@ -1,105 +1,116 @@ -package com.microsoft.openai.samples.rag.ask.approaches; - -import com.azure.ai.openai.models.ChatMessage; -import com.azure.ai.openai.models.ChatRole; -import com.microsoft.openai.samples.rag.approaches.ContentSource; -import java.util.ArrayList; -import java.util.List; - -public class AnswerQuestionChatTemplate { - - private final List conversationHistory = new ArrayList<>(); - - private String customPrompt = ""; - private String systemMessage; - private Boolean replacePrompt = false; - - private static final String SYSTEM_CHAT_MESSAGE_TEMPLATE = """ - You are an intelligent assistant helping Contoso Inc employees with their healthcare plan questions and employee handbook questions. - Use 'you' to refer to the individual asking the questions even if they ask with 'I'. - Answer the user question using only the data provided by the user in his message. - For tabular information return it as an html table. Do not return markdown format. - Each source has a name followed by colon and the actual information, always include the source name for each fact you use in the response. - If you cannot answer say you don't know. - %s - """ ; - - private static final String FEW_SHOT_USER_MESSAGE = """ - What is the deductible for the employee plan for a visit to Overlake in Bellevue?' - Sources: - info1.txt: deductibles depend on whether you are in-network or out-of-network. In-network deductibles are $500 for employee and $1000 for family. Out-of-network deductibles are $1000 for employee and $2000 for family. - info2.pdf: Overlake is in-network for the employee plan. - info3.pdf: Overlake is the name of the area that includes a park and ride near Bellevue. - info4.pdf: In-network institutions include Overlake, Swedish and others in the region - """; - private static final String FEW_SHOT_ASSISTANT_MESSAGE = """ - In-network deductibles are $500 for employee and $1000 for family [info1.txt] and Overlake is in-network for the employee plan [info2.pdf][info4.pdf]. - """; - - /** - * - * @param conversation conversation history - * @param sources domain specific sources to be used in the prompt - * @param customPrompt custom prompt to be injected in the existing promptTemplate or used to replace it - * @param replacePrompt if true, the customPrompt will replace the default promptTemplate, otherwise it will be appended - * to the default promptTemplate in the predefined section - */ - - private static final String GROUNDED_USER_QUESTION_TEMPLATE = """ - %s - Sources: - %s - """; - public AnswerQuestionChatTemplate( String customPrompt, Boolean replacePrompt) { - - if(replacePrompt && (customPrompt == null || customPrompt.isEmpty())) - throw new IllegalStateException("customPrompt cannot be null or empty when replacePrompt is true"); - - this.replacePrompt = replacePrompt; - this.customPrompt = customPrompt == null ? "" : customPrompt; - - - if(this.replacePrompt){ - this.systemMessage = customPrompt; - } else { - this.systemMessage = SYSTEM_CHAT_MESSAGE_TEMPLATE.formatted(this.customPrompt); - } - - //Add system message - ChatMessage chatSystemMessage = new ChatMessage(ChatRole.SYSTEM); - chatSystemMessage.setContent(systemMessage); - - this.conversationHistory.add(chatSystemMessage); - - //Add few shoot learning with chat - ChatMessage fewShotUserMessage = new ChatMessage(ChatRole.USER); - fewShotUserMessage.setContent(FEW_SHOT_USER_MESSAGE); - this.conversationHistory.add(fewShotUserMessage); - - ChatMessage fewShotAssistantMessage = new ChatMessage(ChatRole.ASSISTANT); - fewShotAssistantMessage.setContent(FEW_SHOT_ASSISTANT_MESSAGE); - this.conversationHistory.add(fewShotAssistantMessage); - } - - - public List getMessages(String question,List sources ) { - if (sources == null || sources.isEmpty()) - throw new IllegalStateException("sources cannot be null or empty"); - if (question == null || question.isEmpty()) - throw new IllegalStateException("question cannot be null"); - - StringBuilder sourcesStringBuilder = new StringBuilder(); - // Build sources section - sources.iterator().forEachRemaining(source -> sourcesStringBuilder.append(source.getSourceName()).append(": ").append(source.getSourceContent()).append("\n")); - - //Add user question with retrieved facts - String groundedUserQuestion = GROUNDED_USER_QUESTION_TEMPLATE.formatted(question,sourcesStringBuilder.toString()); - ChatMessage groundedUserMessage = new ChatMessage(ChatRole.USER); - groundedUserMessage.setContent(groundedUserQuestion); - this.conversationHistory.add(groundedUserMessage); - - return this.conversationHistory; - } - - -} +// Copyright (c) Microsoft. All rights reserved. +package com.microsoft.openai.samples.rag.ask.approaches; + +import com.azure.ai.openai.models.ChatMessage; +import com.azure.ai.openai.models.ChatRole; +import com.microsoft.openai.samples.rag.approaches.ContentSource; +import java.util.ArrayList; +import java.util.List; + +public class AnswerQuestionChatTemplate { + + private final List conversationHistory = new ArrayList<>(); + + private String customPrompt = ""; + private String systemMessage; + private Boolean replacePrompt = false; + + private static final String SYSTEM_CHAT_MESSAGE_TEMPLATE = + """ + You are an intelligent assistant helping Contoso Inc employees with their healthcare plan questions and employee handbook questions. + Use 'you' to refer to the individual asking the questions even if they ask with 'I'. + Answer the user question using only the data provided by the user in his message. + For tabular information return it as an html table. Do not return markdown format. + Each source has a name followed by colon and the actual information, always include the source name for each fact you use in the response. + If you cannot answer say you don't know. + %s + """; + + private static final String FEW_SHOT_USER_MESSAGE = + """ + What is the deductible for the employee plan for a visit to Overlake in Bellevue?' + Sources: + info1.txt: deductibles depend on whether you are in-network or out-of-network. In-network deductibles are $500 for employee and $1000 for family. Out-of-network deductibles are $1000 for employee and $2000 for family. + info2.pdf: Overlake is in-network for the employee plan. + info3.pdf: Overlake is the name of the area that includes a park and ride near Bellevue. + info4.pdf: In-network institutions include Overlake, Swedish and others in the region + """; + private static final String FEW_SHOT_ASSISTANT_MESSAGE = + """ + In-network deductibles are $500 for employee and $1000 for family [info1.txt] and Overlake is in-network for the employee plan [info2.pdf][info4.pdf]. + """; + + /** + * @param conversation conversation history + * @param sources domain specific sources to be used in the prompt + * @param customPrompt custom prompt to be injected in the existing promptTemplate or used to + * replace it + * @param replacePrompt if true, the customPrompt will replace the default promptTemplate, + * otherwise it will be appended to the default promptTemplate in the predefined section + */ + private static final String GROUNDED_USER_QUESTION_TEMPLATE = + """ + %s + Sources: + %s + """; + + public AnswerQuestionChatTemplate(String customPrompt, Boolean replacePrompt) { + + if (replacePrompt && (customPrompt == null || customPrompt.isEmpty())) + throw new IllegalStateException( + "customPrompt cannot be null or empty when replacePrompt is true"); + + this.replacePrompt = replacePrompt; + this.customPrompt = customPrompt == null ? "" : customPrompt; + + if (this.replacePrompt) { + this.systemMessage = customPrompt; + } else { + this.systemMessage = SYSTEM_CHAT_MESSAGE_TEMPLATE.formatted(this.customPrompt); + } + + // Add system message + ChatMessage chatSystemMessage = new ChatMessage(ChatRole.SYSTEM); + chatSystemMessage.setContent(systemMessage); + + this.conversationHistory.add(chatSystemMessage); + + // Add few shoot learning with chat + ChatMessage fewShotUserMessage = new ChatMessage(ChatRole.USER); + fewShotUserMessage.setContent(FEW_SHOT_USER_MESSAGE); + this.conversationHistory.add(fewShotUserMessage); + + ChatMessage fewShotAssistantMessage = new ChatMessage(ChatRole.ASSISTANT); + fewShotAssistantMessage.setContent(FEW_SHOT_ASSISTANT_MESSAGE); + this.conversationHistory.add(fewShotAssistantMessage); + } + + public List getMessages(String question, List sources) { + if (sources == null || sources.isEmpty()) + throw new IllegalStateException("sources cannot be null or empty"); + if (question == null || question.isEmpty()) + throw new IllegalStateException("question cannot be null"); + + StringBuilder sourcesStringBuilder = new StringBuilder(); + // Build sources section + sources.iterator() + .forEachRemaining( + source -> + sourcesStringBuilder + .append(source.getSourceName()) + .append(": ") + .append(source.getSourceContent()) + .append("\n")); + + // Add user question with retrieved facts + String groundedUserQuestion = + GROUNDED_USER_QUESTION_TEMPLATE.formatted( + question, sourcesStringBuilder.toString()); + ChatMessage groundedUserMessage = new ChatMessage(ChatRole.USER); + groundedUserMessage.setContent(groundedUserQuestion); + this.conversationHistory.add(groundedUserMessage); + + return this.conversationHistory; + } +} diff --git a/app/backend/src/main/java/com/microsoft/openai/samples/rag/ask/approaches/PlainJavaAskApproach.java b/app/backend/src/main/java/com/microsoft/openai/samples/rag/ask/approaches/PlainJavaAskApproach.java index 36bee9f..9137ca3 100644 --- a/app/backend/src/main/java/com/microsoft/openai/samples/rag/ask/approaches/PlainJavaAskApproach.java +++ b/app/backend/src/main/java/com/microsoft/openai/samples/rag/ask/approaches/PlainJavaAskApproach.java @@ -1,153 +1,174 @@ -package com.microsoft.openai.samples.rag.ask.approaches; - -import com.azure.ai.openai.models.ChatChoice; -import com.azure.ai.openai.models.ChatCompletions; -import com.azure.core.util.IterableStream; -import com.fasterxml.jackson.databind.ObjectMapper; -import com.microsoft.openai.samples.rag.approaches.ContentSource; -import com.microsoft.openai.samples.rag.approaches.RAGApproach; -import com.microsoft.openai.samples.rag.approaches.RAGOptions; -import com.microsoft.openai.samples.rag.approaches.RAGResponse; -import com.microsoft.openai.samples.rag.common.ChatGPTUtils; -import com.microsoft.openai.samples.rag.controller.ChatResponse; -import com.microsoft.openai.samples.rag.proxy.OpenAIProxy; -import com.microsoft.openai.samples.rag.retrieval.FactsRetrieverProvider; -import com.microsoft.openai.samples.rag.retrieval.Retriever; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; -import org.springframework.stereotype.Component; - -import java.io.IOException; -import java.io.OutputStream; -import java.util.List; - -/** - Use Cognitive Search and Java OpenAI APIs. - It first retrieves top documents from search and use them to build a prompt. - Then, it uses OpenAI to generate an answer for the user question. - Several cognitive search retrieval options are available: Text, Vector, Hybrid. - When Hybrid and Vector are selected an additional call to OpenAI is required to generate embeddings vector for the question. - */ -@Component -public class PlainJavaAskApproach implements RAGApproach { - - private static final Logger LOGGER = LoggerFactory.getLogger(PlainJavaAskApproach.class); - private final OpenAIProxy openAIProxy; - private final FactsRetrieverProvider factsRetrieverProvider; - private final ObjectMapper objectMapper; - - public PlainJavaAskApproach(FactsRetrieverProvider factsRetrieverProvider, OpenAIProxy openAIProxy, ObjectMapper objectMapper) { - this.factsRetrieverProvider = factsRetrieverProvider; - this.openAIProxy = openAIProxy; - this.objectMapper = objectMapper; - } - - /** - * @param question - * @param options - * @return - */ - @Override - public RAGResponse run(String question, RAGOptions options) { - //Get instance of retriever based on the retrieval mode: hybryd, text, vectors. - Retriever factsRetriever = factsRetrieverProvider.getFactsRetriever(options); - - //STEP 1: Retrieve relevant documents using user question as query - List sources = factsRetriever.retrieveFromQuestion(question, options); - LOGGER.info("Total {} sources found in cognitive search for keyword search query[{}]", sources.size(), - question); - - var customPrompt = options.getPromptTemplate(); - var customPromptEmpty = (customPrompt == null) || (customPrompt != null && customPrompt.isEmpty()); - - //true will replace the default prompt. False will add custom prompt as suffix to the default prompt - var replacePrompt = !customPromptEmpty && !customPrompt.startsWith("|"); - if (!replacePrompt && !customPromptEmpty) { - customPrompt = customPrompt.substring(1); - } - - //STEP 2: Build a prompt using RAG options to see if prompt should be replaced or extended. - var answerQuestionChatTemplate = new AnswerQuestionChatTemplate(customPrompt, replacePrompt); - - //STEP 3: Build the chat conversation with grounded messages using the retrieved facts - var groundedChatMessages = answerQuestionChatTemplate.getMessages(question, sources); - var chatCompletionsOptions = ChatGPTUtils.buildDefaultChatCompletionsOptions(groundedChatMessages); - - // STEP 4: Generate a contextual and content specific answer - ChatCompletions chatCompletions = openAIProxy.getChatCompletions(chatCompletionsOptions); - - LOGGER.info("Chat completion generated with Prompt Tokens[{}], Completions Tokens[{}], Total Tokens[{}]", - chatCompletions.getUsage().getPromptTokens(), - chatCompletions.getUsage().getCompletionTokens(), - chatCompletions.getUsage().getTotalTokens()); - - return new RAGResponse.Builder() - .question(question) - .prompt(ChatGPTUtils.formatAsChatML(groundedChatMessages)) - .answer(chatCompletions.getChoices().get(0).getMessage().getContent()) - .sources(sources) - .build(); - } - - @Override - public void runStreaming(String question, RAGOptions options, OutputStream outputStream) { - //Get instance of retriever based on the retrieval mode: hybryd, text, vectors. - Retriever factsRetriever = factsRetrieverProvider.getFactsRetriever(options); - List sources = factsRetriever.retrieveFromQuestion(question, options); - LOGGER.info("Total {} sources found in cognitive search for keyword search query[{}]", sources.size(), - question); - - var customPrompt = options.getPromptTemplate(); - var customPromptEmpty = (customPrompt == null) || (customPrompt != null && customPrompt.isEmpty()); - - //true will replace the default prompt. False will add custom prompt as suffix to the default prompt - var replacePrompt = !customPromptEmpty && !customPrompt.startsWith("|"); - if (!replacePrompt && !customPromptEmpty) { - customPrompt = customPrompt.substring(1); - } - - var answerQuestionChatTemplate = new AnswerQuestionChatTemplate(customPrompt, replacePrompt); - - var groundedChatMessages = answerQuestionChatTemplate.getMessages(question, sources); - var chatCompletionsOptions = ChatGPTUtils.buildDefaultChatCompletionsOptions(groundedChatMessages); - - IterableStream completions = openAIProxy.getChatCompletionsStream(chatCompletionsOptions); - int index = 0; - for (ChatCompletions completion : completions) { - - LOGGER.info("Chat completion generated with Prompt Tokens[{}], Completions Tokens[{}], Total Tokens[{}]", - completion.getUsage().getPromptTokens(), - completion.getUsage().getCompletionTokens(), - completion.getUsage().getTotalTokens()); - - for (ChatChoice choice : completion.getChoices()) { - if (choice.getDelta().getContent() == null) { - continue; - } - - RAGResponse ragResponse = new RAGResponse.Builder() - .question(question) - .prompt(ChatGPTUtils.formatAsChatML(groundedChatMessages)) - .answer(choice.getMessage().getContent()) - .sources(sources) - .build(); - - ChatResponse response; - if (index == 0) { - response = ChatResponse.buildChatResponse(ragResponse); - } else { - response = ChatResponse.buildChatDeltaResponse(index, ragResponse); - } - index++; - - try { - String value = objectMapper.writeValueAsString(response) + "\n"; - outputStream.write(value.getBytes()); - outputStream.flush(); - } catch (IOException e) { - throw new RuntimeException(e); - } - } - } - } -} +// Copyright (c) Microsoft. All rights reserved. +package com.microsoft.openai.samples.rag.ask.approaches; + +import com.azure.ai.openai.models.ChatChoice; +import com.azure.ai.openai.models.ChatCompletions; +import com.azure.core.util.IterableStream; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.microsoft.openai.samples.rag.approaches.ContentSource; +import com.microsoft.openai.samples.rag.approaches.RAGApproach; +import com.microsoft.openai.samples.rag.approaches.RAGOptions; +import com.microsoft.openai.samples.rag.approaches.RAGResponse; +import com.microsoft.openai.samples.rag.common.ChatGPTUtils; +import com.microsoft.openai.samples.rag.controller.ChatResponse; +import com.microsoft.openai.samples.rag.proxy.OpenAIProxy; +import com.microsoft.openai.samples.rag.retrieval.FactsRetrieverProvider; +import com.microsoft.openai.samples.rag.retrieval.Retriever; +import java.io.IOException; +import java.io.OutputStream; +import java.util.List; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.stereotype.Component; + +/** + * Use Cognitive Search and Java OpenAI APIs. It first retrieves top documents from search and use + * them to build a prompt. Then, it uses OpenAI to generate an answer for the user question. Several + * cognitive search retrieval options are available: Text, Vector, Hybrid. When Hybrid and Vector + * are selected an additional call to OpenAI is required to generate embeddings vector for the + * question. + */ +@Component +public class PlainJavaAskApproach implements RAGApproach { + + private static final Logger LOGGER = LoggerFactory.getLogger(PlainJavaAskApproach.class); + private final OpenAIProxy openAIProxy; + private final FactsRetrieverProvider factsRetrieverProvider; + private final ObjectMapper objectMapper; + + public PlainJavaAskApproach( + FactsRetrieverProvider factsRetrieverProvider, + OpenAIProxy openAIProxy, + ObjectMapper objectMapper) { + this.factsRetrieverProvider = factsRetrieverProvider; + this.openAIProxy = openAIProxy; + this.objectMapper = objectMapper; + } + + /** + * @param question + * @param options + * @return + */ + @Override + public RAGResponse run(String question, RAGOptions options) { + // Get instance of retriever based on the retrieval mode: hybryd, text, vectors. + Retriever factsRetriever = factsRetrieverProvider.getFactsRetriever(options); + + // STEP 1: Retrieve relevant documents using user question as query + List sources = factsRetriever.retrieveFromQuestion(question, options); + LOGGER.info( + "Total {} sources found in cognitive search for keyword search query[{}]", + sources.size(), + question); + + var customPrompt = options.getPromptTemplate(); + var customPromptEmpty = + (customPrompt == null) || (customPrompt != null && customPrompt.isEmpty()); + + // true will replace the default prompt. False will add custom prompt as suffix to the + // default prompt + var replacePrompt = !customPromptEmpty && !customPrompt.startsWith("|"); + if (!replacePrompt && !customPromptEmpty) { + customPrompt = customPrompt.substring(1); + } + + // STEP 2: Build a prompt using RAG options to see if prompt should be replaced or extended. + var answerQuestionChatTemplate = + new AnswerQuestionChatTemplate(customPrompt, replacePrompt); + + // STEP 3: Build the chat conversation with grounded messages using the retrieved facts + var groundedChatMessages = answerQuestionChatTemplate.getMessages(question, sources); + var chatCompletionsOptions = + ChatGPTUtils.buildDefaultChatCompletionsOptions(groundedChatMessages); + + // STEP 4: Generate a contextual and content specific answer + ChatCompletions chatCompletions = openAIProxy.getChatCompletions(chatCompletionsOptions); + + LOGGER.info( + "Chat completion generated with Prompt Tokens[{}], Completions Tokens[{}], Total" + + " Tokens[{}]", + chatCompletions.getUsage().getPromptTokens(), + chatCompletions.getUsage().getCompletionTokens(), + chatCompletions.getUsage().getTotalTokens()); + + return new RAGResponse.Builder() + .question(question) + .prompt(ChatGPTUtils.formatAsChatML(groundedChatMessages)) + .answer(chatCompletions.getChoices().get(0).getMessage().getContent()) + .sources(sources) + .build(); + } + + @Override + public void runStreaming(String question, RAGOptions options, OutputStream outputStream) { + // Get instance of retriever based on the retrieval mode: hybryd, text, vectors. + Retriever factsRetriever = factsRetrieverProvider.getFactsRetriever(options); + List sources = factsRetriever.retrieveFromQuestion(question, options); + LOGGER.info( + "Total {} sources found in cognitive search for keyword search query[{}]", + sources.size(), + question); + + var customPrompt = options.getPromptTemplate(); + var customPromptEmpty = + (customPrompt == null) || (customPrompt != null && customPrompt.isEmpty()); + + // true will replace the default prompt. False will add custom prompt as suffix to the + // default prompt + var replacePrompt = !customPromptEmpty && !customPrompt.startsWith("|"); + if (!replacePrompt && !customPromptEmpty) { + customPrompt = customPrompt.substring(1); + } + + var answerQuestionChatTemplate = + new AnswerQuestionChatTemplate(customPrompt, replacePrompt); + + var groundedChatMessages = answerQuestionChatTemplate.getMessages(question, sources); + var chatCompletionsOptions = + ChatGPTUtils.buildDefaultChatCompletionsOptions(groundedChatMessages); + + IterableStream completions = + openAIProxy.getChatCompletionsStream(chatCompletionsOptions); + int index = 0; + for (ChatCompletions completion : completions) { + + LOGGER.info( + "Chat completion generated with Prompt Tokens[{}], Completions Tokens[{}]," + + " Total Tokens[{}]", + completion.getUsage().getPromptTokens(), + completion.getUsage().getCompletionTokens(), + completion.getUsage().getTotalTokens()); + + for (ChatChoice choice : completion.getChoices()) { + if (choice.getDelta().getContent() == null) { + continue; + } + + RAGResponse ragResponse = + new RAGResponse.Builder() + .question(question) + .prompt(ChatGPTUtils.formatAsChatML(groundedChatMessages)) + .answer(choice.getMessage().getContent()) + .sources(sources) + .build(); + + ChatResponse response; + if (index == 0) { + response = ChatResponse.buildChatResponse(ragResponse); + } else { + response = ChatResponse.buildChatDeltaResponse(index, ragResponse); + } + index++; + + try { + String value = objectMapper.writeValueAsString(response) + "\n"; + outputStream.write(value.getBytes()); + outputStream.flush(); + } catch (IOException e) { + throw new RuntimeException(e); + } + } + } + } +} diff --git a/app/backend/src/main/java/com/microsoft/openai/samples/rag/ask/approaches/semantickernel/CognitiveSearchPlugin.java b/app/backend/src/main/java/com/microsoft/openai/samples/rag/ask/approaches/semantickernel/CognitiveSearchPlugin.java index 5eafc92..884db99 100644 --- a/app/backend/src/main/java/com/microsoft/openai/samples/rag/ask/approaches/semantickernel/CognitiveSearchPlugin.java +++ b/app/backend/src/main/java/com/microsoft/openai/samples/rag/ask/approaches/semantickernel/CognitiveSearchPlugin.java @@ -1,58 +1,60 @@ -package com.microsoft.openai.samples.rag.ask.approaches.semantickernel; - -import com.azure.core.util.Context; -import com.azure.search.documents.SearchDocument; -import com.azure.search.documents.models.*; -import com.azure.search.documents.util.SearchPagedIterable; -import com.microsoft.openai.samples.rag.approaches.ContentSource; -import com.microsoft.openai.samples.rag.approaches.RAGOptions; -import com.microsoft.openai.samples.rag.proxy.CognitiveSearchProxy; -import com.microsoft.openai.samples.rag.proxy.OpenAIProxy; -import com.microsoft.openai.samples.rag.retrieval.CognitiveSearchRetriever; -import com.microsoft.semantickernel.skilldefinition.annotations.DefineSKFunction; -import com.microsoft.semantickernel.skilldefinition.annotations.SKFunctionInputAttribute; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; -import reactor.core.publisher.Mono; - -import java.util.List; -import java.util.Optional; -import java.util.concurrent.atomic.AtomicInteger; - - -public class CognitiveSearchPlugin { - - private static final Logger LOGGER = LoggerFactory.getLogger(CognitiveSearchPlugin.class); - private final CognitiveSearchProxy cognitiveSearchProxy; - private final OpenAIProxy openAIProxy; - private final RAGOptions options; - - public CognitiveSearchPlugin(CognitiveSearchProxy cognitiveSearchProxy, OpenAIProxy openAIProxy , RAGOptions options) { - this.cognitiveSearchProxy = cognitiveSearchProxy; - this.options = options; - this.openAIProxy = openAIProxy; - } - - @DefineSKFunction(name = "Search", description = "Search information relevant to answering a given query") - public Mono search( - @SKFunctionInputAttribute(description = "the query to answer") - String query - ) { - - CognitiveSearchRetriever retriever = new CognitiveSearchRetriever(this.cognitiveSearchProxy, this.openAIProxy); - List sources = retriever.retrieveFromQuestion(query, this.options); - - LOGGER.info("Total {} sources found in cognitive search for keyword search query[{}]", sources.size(), - query); - - StringBuilder sourcesStringBuilder = new StringBuilder(); - // Build sources section - sources.iterator().forEachRemaining(source -> sourcesStringBuilder.append( - source.getSourceName()) - .append(": ") - .append(source.getSourceContent().replace("\n", "")) - .append("\n")); - return Mono.just(sourcesStringBuilder.toString()); - } - -} \ No newline at end of file +// Copyright (c) Microsoft. All rights reserved. +package com.microsoft.openai.samples.rag.ask.approaches.semantickernel; + +import com.azure.search.documents.models.*; +import com.microsoft.openai.samples.rag.approaches.ContentSource; +import com.microsoft.openai.samples.rag.approaches.RAGOptions; +import com.microsoft.openai.samples.rag.proxy.CognitiveSearchProxy; +import com.microsoft.openai.samples.rag.proxy.OpenAIProxy; +import com.microsoft.openai.samples.rag.retrieval.CognitiveSearchRetriever; +import com.microsoft.semantickernel.skilldefinition.annotations.DefineSKFunction; +import com.microsoft.semantickernel.skilldefinition.annotations.SKFunctionInputAttribute; +import java.util.List; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import reactor.core.publisher.Mono; + +public class CognitiveSearchPlugin { + + private static final Logger LOGGER = LoggerFactory.getLogger(CognitiveSearchPlugin.class); + private final CognitiveSearchProxy cognitiveSearchProxy; + private final OpenAIProxy openAIProxy; + private final RAGOptions options; + + public CognitiveSearchPlugin( + CognitiveSearchProxy cognitiveSearchProxy, + OpenAIProxy openAIProxy, + RAGOptions options) { + this.cognitiveSearchProxy = cognitiveSearchProxy; + this.options = options; + this.openAIProxy = openAIProxy; + } + + @DefineSKFunction( + name = "Search", + description = "Search information relevant to answering a given query") + public Mono search( + @SKFunctionInputAttribute(description = "the query to answer") String query) { + + CognitiveSearchRetriever retriever = + new CognitiveSearchRetriever(this.cognitiveSearchProxy, this.openAIProxy); + List sources = retriever.retrieveFromQuestion(query, this.options); + + LOGGER.info( + "Total {} sources found in cognitive search for keyword search query[{}]", + sources.size(), + query); + + StringBuilder sourcesStringBuilder = new StringBuilder(); + // Build sources section + sources.iterator() + .forEachRemaining( + source -> + sourcesStringBuilder + .append(source.getSourceName()) + .append(": ") + .append(source.getSourceContent().replace("\n", "")) + .append("\n")); + return Mono.just(sourcesStringBuilder.toString()); + } +} diff --git a/app/backend/src/main/java/com/microsoft/openai/samples/rag/ask/approaches/semantickernel/JavaSemanticKernelChainsApproach.java b/app/backend/src/main/java/com/microsoft/openai/samples/rag/ask/approaches/semantickernel/JavaSemanticKernelChainsApproach.java index 0195526..4b8ab75 100644 --- a/app/backend/src/main/java/com/microsoft/openai/samples/rag/ask/approaches/semantickernel/JavaSemanticKernelChainsApproach.java +++ b/app/backend/src/main/java/com/microsoft/openai/samples/rag/ask/approaches/semantickernel/JavaSemanticKernelChainsApproach.java @@ -1,153 +1,163 @@ -package com.microsoft.openai.samples.rag.ask.approaches.semantickernel; - -import com.azure.ai.openai.OpenAIAsyncClient; -import com.azure.core.annotation.Get; -import com.microsoft.openai.samples.rag.approaches.ContentSource; -import com.microsoft.openai.samples.rag.approaches.RAGApproach; -import com.microsoft.openai.samples.rag.approaches.RAGOptions; -import com.microsoft.openai.samples.rag.approaches.RAGResponse; -import com.microsoft.openai.samples.rag.proxy.CognitiveSearchProxy; -import com.microsoft.openai.samples.rag.proxy.OpenAIProxy; -import com.microsoft.semantickernel.Kernel; -import com.microsoft.semantickernel.SKBuilders; -import com.microsoft.semantickernel.orchestration.SKContext; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; -import org.springframework.beans.factory.annotation.Value; -import org.springframework.stereotype.Component; -import reactor.core.publisher.Flux; - -import java.io.OutputStream; -import java.util.Arrays; -import java.util.Collections; -import java.util.List; -import java.util.Objects; -import java.util.stream.Collectors; - -/** - * Use Java Semantic Kernel framework with semantic and native functions chaining. It uses an imperative style for AI orchestration through semantic kernel functions chaining. - * InformationFinder.Search native function and RAG.AnswerQuestion semantic function are called sequentially. - * Several cognitive search retrieval options are available: Text, Vector, Hybrid. - */ -@Component -public class JavaSemanticKernelChainsApproach implements RAGApproach { - - private static final Logger LOGGER = LoggerFactory.getLogger(JavaSemanticKernelChainsApproach.class); - private static final String PLAN_PROMPT = """ - Take the input as a question and answer it finding any information needed - """; - private final CognitiveSearchProxy cognitiveSearchProxy; - - private final OpenAIProxy openAIProxy; - - private final OpenAIAsyncClient openAIAsyncClient; - - @Value("${openai.chatgpt.deployment}") - private String gptChatDeploymentModelId; - - public JavaSemanticKernelChainsApproach(CognitiveSearchProxy cognitiveSearchProxy, OpenAIAsyncClient openAIAsyncClient, OpenAIProxy openAIProxy) { - this.cognitiveSearchProxy = cognitiveSearchProxy; - this.openAIAsyncClient = openAIAsyncClient; - this.openAIProxy = openAIProxy; - } - - /** - * @param question - * @param options - * @return - */ - @Override - public RAGResponse run(String question, RAGOptions options) { - - //Build semantic kernel context - Kernel semanticKernel = buildSemanticKernel(options); - - - //STEP 1: Retrieve relevant documents using user question. It reuses the CognitiveSearchRetriever appraoch through the CognitiveSearchPlugin native function. - SKContext searchContext = - semanticKernel.runAsync( - question, - semanticKernel.getSkill("InformationFinder").getFunction("Search", null)).block(); - - var sources = formSourcesList(searchContext.getResult()); - - //STEP 2: Build a SK context with the sources retrieved from the memory store and the user question. - var answerVariables = SKBuilders.variables() - .withVariable("sources", searchContext.getResult()) - .withVariable("input", question) - .build(); - - /** - * STEP 3: - * Get a reference of the semantic function [AnswerQuestion] of the [RAG] plugin (a.k.a. skill) from the SK skills registry and provide it with the pre-built context. - * Triggering Open AI to get an answerVariables. - */ - SKContext answerExecutionContext = - semanticKernel.runAsync(answerVariables, - semanticKernel.getSkill("RAG").getFunction("AnswerQuestion", null)).block(); - return new RAGResponse.Builder() - .prompt("Prompt is managed by Semantic Kernel") - .answer(answerExecutionContext.getResult()) - .sources(sources) - .sourcesAsText(searchContext.getResult()) - .question(question) - .build(); - - } - - @Override - public void runStreaming(String questionOrConversation, RAGOptions options, OutputStream outputStream) { - throw new IllegalStateException("Streaming not supported for this approach"); - } - - private List formSourcesList(String result) { - if (result == null) { - return Collections.emptyList(); - } - return Arrays.stream(result - .split("\n")) - .map(source -> { - String[] split = source.split(":", 2); - if (split.length >= 2) { - var sourceName = split[0].trim(); - var sourceContent = split[1].trim(); - return new ContentSource(sourceName, sourceContent); - } else { - return null; - } - }) - .filter(Objects::nonNull) - .collect(Collectors.toList()); - } - - /** - * Build semantic kernel context with AnswerQuestion semantic function and InformationFinder.Search native function. - * AnswerQuestion is imported from src/main/resources/semantickernel/Plugins. - * InformationFinder.Search is implemented in a traditional Java class method: CognitiveSearchPlugin.search - * - * @param options - * @return - */ - private Kernel buildSemanticKernel( RAGOptions options) { - Kernel kernel = SKBuilders.kernel() - .withDefaultAIService(SKBuilders.chatCompletion() - .withModelId(gptChatDeploymentModelId) - .withOpenAIClient(this.openAIAsyncClient) - .build()) - .build(); - - kernel.importSkill(new CognitiveSearchPlugin(this.cognitiveSearchProxy, this.openAIProxy,options), "InformationFinder"); - - kernel.importSkillFromResources( - "semantickernel/Plugins", - "RAG", - "AnswerQuestion", - null - ); - - return kernel; - } - - - -} +// Copyright (c) Microsoft. All rights reserved. +package com.microsoft.openai.samples.rag.ask.approaches.semantickernel; + +import com.azure.ai.openai.OpenAIAsyncClient; +import com.microsoft.openai.samples.rag.approaches.ContentSource; +import com.microsoft.openai.samples.rag.approaches.RAGApproach; +import com.microsoft.openai.samples.rag.approaches.RAGOptions; +import com.microsoft.openai.samples.rag.approaches.RAGResponse; +import com.microsoft.openai.samples.rag.proxy.CognitiveSearchProxy; +import com.microsoft.openai.samples.rag.proxy.OpenAIProxy; +import com.microsoft.semantickernel.Kernel; +import com.microsoft.semantickernel.SKBuilders; +import com.microsoft.semantickernel.orchestration.SKContext; +import java.io.OutputStream; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Objects; +import java.util.stream.Collectors; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.beans.factory.annotation.Value; +import org.springframework.stereotype.Component; + +/** + * Use Java Semantic Kernel framework with semantic and native functions chaining. It uses an + * imperative style for AI orchestration through semantic kernel functions chaining. + * InformationFinder.Search native function and RAG.AnswerQuestion semantic function are called + * sequentially. Several cognitive search retrieval options are available: Text, Vector, Hybrid. + */ +@Component +public class JavaSemanticKernelChainsApproach implements RAGApproach { + + private static final Logger LOGGER = + LoggerFactory.getLogger(JavaSemanticKernelChainsApproach.class); + private static final String PLAN_PROMPT = + """ + Take the input as a question and answer it finding any information needed + """; + private final CognitiveSearchProxy cognitiveSearchProxy; + + private final OpenAIProxy openAIProxy; + + private final OpenAIAsyncClient openAIAsyncClient; + + @Value("${openai.chatgpt.deployment}") + private String gptChatDeploymentModelId; + + public JavaSemanticKernelChainsApproach( + CognitiveSearchProxy cognitiveSearchProxy, + OpenAIAsyncClient openAIAsyncClient, + OpenAIProxy openAIProxy) { + this.cognitiveSearchProxy = cognitiveSearchProxy; + this.openAIAsyncClient = openAIAsyncClient; + this.openAIProxy = openAIProxy; + } + + /** + * @param question + * @param options + * @return + */ + @Override + public RAGResponse run(String question, RAGOptions options) { + + // Build semantic kernel context + Kernel semanticKernel = buildSemanticKernel(options); + + // STEP 1: Retrieve relevant documents using user question. It reuses the + // CognitiveSearchRetriever appraoch through the CognitiveSearchPlugin native function. + SKContext searchContext = + semanticKernel + .runAsync( + question, + semanticKernel + .getSkill("InformationFinder") + .getFunction("Search", null)) + .block(); + + var sources = formSourcesList(searchContext.getResult()); + + // STEP 2: Build a SK context with the sources retrieved from the memory store and the user + // question. + var answerVariables = + SKBuilders.variables() + .withVariable("sources", searchContext.getResult()) + .withVariable("input", question) + .build(); + + /** + * STEP 3: Get a reference of the semantic function [AnswerQuestion] of the [RAG] plugin + * (a.k.a. skill) from the SK skills registry and provide it with the pre-built context. + * Triggering Open AI to get an answerVariables. + */ + SKContext answerExecutionContext = + semanticKernel + .runAsync( + answerVariables, + semanticKernel.getSkill("RAG").getFunction("AnswerQuestion", null)) + .block(); + return new RAGResponse.Builder() + .prompt("Prompt is managed by Semantic Kernel") + .answer(answerExecutionContext.getResult()) + .sources(sources) + .sourcesAsText(searchContext.getResult()) + .question(question) + .build(); + } + + @Override + public void runStreaming( + String questionOrConversation, RAGOptions options, OutputStream outputStream) { + throw new IllegalStateException("Streaming not supported for this approach"); + } + + private List formSourcesList(String result) { + if (result == null) { + return Collections.emptyList(); + } + return Arrays.stream(result.split("\n")) + .map( + source -> { + String[] split = source.split(":", 2); + if (split.length >= 2) { + var sourceName = split[0].trim(); + var sourceContent = split[1].trim(); + return new ContentSource(sourceName, sourceContent); + } else { + return null; + } + }) + .filter(Objects::nonNull) + .collect(Collectors.toList()); + } + + /** + * Build semantic kernel context with AnswerQuestion semantic function and + * InformationFinder.Search native function. AnswerQuestion is imported from + * src/main/resources/semantickernel/Plugins. InformationFinder.Search is implemented in a + * traditional Java class method: CognitiveSearchPlugin.search + * + * @param options + * @return + */ + private Kernel buildSemanticKernel(RAGOptions options) { + Kernel kernel = + SKBuilders.kernel() + .withDefaultAIService( + SKBuilders.chatCompletion() + .withModelId(gptChatDeploymentModelId) + .withOpenAIClient(this.openAIAsyncClient) + .build()) + .build(); + + kernel.importSkill( + new CognitiveSearchPlugin(this.cognitiveSearchProxy, this.openAIProxy, options), + "InformationFinder"); + + kernel.importSkillFromResources("semantickernel/Plugins", "RAG", "AnswerQuestion", null); + + return kernel; + } +} diff --git a/app/backend/src/main/java/com/microsoft/openai/samples/rag/ask/approaches/semantickernel/JavaSemanticKernelPlannerApproach.java b/app/backend/src/main/java/com/microsoft/openai/samples/rag/ask/approaches/semantickernel/JavaSemanticKernelPlannerApproach.java index 92accbe..f3e6417 100644 --- a/app/backend/src/main/java/com/microsoft/openai/samples/rag/ask/approaches/semantickernel/JavaSemanticKernelPlannerApproach.java +++ b/app/backend/src/main/java/com/microsoft/openai/samples/rag/ask/approaches/semantickernel/JavaSemanticKernelPlannerApproach.java @@ -1,124 +1,124 @@ -package com.microsoft.openai.samples.rag.ask.approaches.semantickernel; - -import com.azure.ai.openai.OpenAIAsyncClient; -import com.microsoft.openai.samples.rag.approaches.RAGApproach; -import com.microsoft.openai.samples.rag.approaches.RAGOptions; -import com.microsoft.openai.samples.rag.approaches.RAGResponse; -import com.microsoft.openai.samples.rag.proxy.CognitiveSearchProxy; -import com.microsoft.openai.samples.rag.proxy.OpenAIProxy; -import com.microsoft.semantickernel.Kernel; -import com.microsoft.semantickernel.SKBuilders; -import com.microsoft.semantickernel.orchestration.SKContext; -import com.microsoft.semantickernel.planner.sequentialplanner.SequentialPlanner; -import com.microsoft.semantickernel.planner.sequentialplanner.SequentialPlannerRequestSettings; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; -import org.springframework.beans.factory.annotation.Value; -import org.springframework.stereotype.Component; -import reactor.core.publisher.Flux; - -import java.io.OutputStream; -import java.util.Objects; -import java.util.Set; - -/** - * Use Java Semantic Kernel framework with built-in Planner for functions orchestration. - * It uses a declarative style for AI orchestration through the built-in SequentialPlanner. - * SequentialPlanner call OpenAI to generate a plan for answering a question using available plugins: InformationFinder and RAG - */ -@Component -public class JavaSemanticKernelPlannerApproach implements RAGApproach { - - private static final Logger LOGGER = LoggerFactory.getLogger(JavaSemanticKernelPlannerApproach.class); - private static final String GOAL_PROMPT = """ - Take the input as a question and answer it finding any information needed - """; - private final CognitiveSearchProxy cognitiveSearchProxy; - - private final OpenAIProxy openAIProxy; - - private final OpenAIAsyncClient openAIAsyncClient; - - @Value("${openai.chatgpt.deployment}") - private String gptChatDeploymentModelId; - - public JavaSemanticKernelPlannerApproach(CognitiveSearchProxy cognitiveSearchProxy, OpenAIAsyncClient openAIAsyncClient, OpenAIProxy openAIProxy) { - this.cognitiveSearchProxy = cognitiveSearchProxy; - this.openAIAsyncClient = openAIAsyncClient; - this.openAIProxy = openAIProxy; - } - - /** - * @param question - * @param options - * @return - */ - @Override - public RAGResponse run(String question, RAGOptions options) { - - //Build semantic kernel context - Kernel semanticKernel = buildSemanticKernel(options); - - SequentialPlanner sequentialPlanner = new SequentialPlanner(semanticKernel, new SequentialPlannerRequestSettings( - 0.7f, - 100, - Set.of(), - Set.of(), - Set.of(), - 1024 - ), null); - - //STEP 1: ask Open AI to generate an execution plan for the goal contained in GOAL_PROMPT. - var plan = Objects.requireNonNull(sequentialPlanner.createPlanAsync(GOAL_PROMPT).block()); - - LOGGER.debug("Semantic kernel plan calculated is [{}]", plan.toPlanString()); - - //STEP 2: execute the plan calculated by the planner using Open AI - SKContext planContext = Objects.requireNonNull(plan.invokeAsync(question).block()); - - return new RAGResponse.Builder() - .prompt(plan.toPlanString()) - .answer(planContext.getResult()) - //.sourcesAsText(planContext.getVariables().get("sources")) - .sourcesAsText("sources placeholders") - .question(question) - .build(); - - } - - @Override - public void runStreaming(String questionOrConversation, RAGOptions options, OutputStream outputStream) { - throw new IllegalStateException("Streaming not supported for this approach"); - } - - /** - * Build semantic kernel context with AnswerQuestion semantic function and InformationFinder.Search native function. - * AnswerQuestion is imported from src/main/resources/semantickernel/Plugins. - * InformationFinder.Search is implemented in a traditional Java class method: CognitiveSearchPlugin.search - * - * @param options - * @return - */ - private Kernel buildSemanticKernel( RAGOptions options) { - Kernel kernel = SKBuilders.kernel() - .withDefaultAIService(SKBuilders.chatCompletion() - .withModelId(gptChatDeploymentModelId) - .withOpenAIClient(this.openAIAsyncClient) - .build()) - .build(); - - kernel.importSkill(new CognitiveSearchPlugin(this.cognitiveSearchProxy, this.openAIProxy,options), "InformationFinder"); - - kernel.importSkillFromResources( - "semantickernel/Plugins", - "RAG", - "AnswerQuestion", - null - ); - - return kernel; - } - - - -} +// Copyright (c) Microsoft. All rights reserved. +package com.microsoft.openai.samples.rag.ask.approaches.semantickernel; + +import com.azure.ai.openai.OpenAIAsyncClient; +import com.microsoft.openai.samples.rag.approaches.RAGApproach; +import com.microsoft.openai.samples.rag.approaches.RAGOptions; +import com.microsoft.openai.samples.rag.approaches.RAGResponse; +import com.microsoft.openai.samples.rag.proxy.CognitiveSearchProxy; +import com.microsoft.openai.samples.rag.proxy.OpenAIProxy; +import com.microsoft.semantickernel.Kernel; +import com.microsoft.semantickernel.SKBuilders; +import com.microsoft.semantickernel.orchestration.SKContext; +import com.microsoft.semantickernel.planner.sequentialplanner.SequentialPlanner; +import com.microsoft.semantickernel.planner.sequentialplanner.SequentialPlannerRequestSettings; +import java.io.OutputStream; +import java.util.Objects; +import java.util.Set; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.beans.factory.annotation.Value; +import org.springframework.stereotype.Component; + +/** + * Use Java Semantic Kernel framework with built-in Planner for functions orchestration. It uses a + * declarative style for AI orchestration through the built-in SequentialPlanner. SequentialPlanner + * call OpenAI to generate a plan for answering a question using available plugins: + * InformationFinder and RAG + */ +@Component +public class JavaSemanticKernelPlannerApproach implements RAGApproach { + + private static final Logger LOGGER = + LoggerFactory.getLogger(JavaSemanticKernelPlannerApproach.class); + private static final String GOAL_PROMPT = + """ + Take the input as a question and answer it finding any information needed + """; + private final CognitiveSearchProxy cognitiveSearchProxy; + + private final OpenAIProxy openAIProxy; + + private final OpenAIAsyncClient openAIAsyncClient; + + @Value("${openai.chatgpt.deployment}") + private String gptChatDeploymentModelId; + + public JavaSemanticKernelPlannerApproach( + CognitiveSearchProxy cognitiveSearchProxy, + OpenAIAsyncClient openAIAsyncClient, + OpenAIProxy openAIProxy) { + this.cognitiveSearchProxy = cognitiveSearchProxy; + this.openAIAsyncClient = openAIAsyncClient; + this.openAIProxy = openAIProxy; + } + + /** + * @param question + * @param options + * @return + */ + @Override + public RAGResponse run(String question, RAGOptions options) { + + // Build semantic kernel context + Kernel semanticKernel = buildSemanticKernel(options); + + SequentialPlanner sequentialPlanner = + new SequentialPlanner( + semanticKernel, + new SequentialPlannerRequestSettings( + 0.7f, 100, Set.of(), Set.of(), Set.of(), 1024), + null); + + // STEP 1: ask Open AI to generate an execution plan for the goal contained in GOAL_PROMPT. + var plan = Objects.requireNonNull(sequentialPlanner.createPlanAsync(GOAL_PROMPT).block()); + + LOGGER.debug("Semantic kernel plan calculated is [{}]", plan.toPlanString()); + + // STEP 2: execute the plan calculated by the planner using Open AI + SKContext planContext = Objects.requireNonNull(plan.invokeAsync(question).block()); + + return new RAGResponse.Builder() + .prompt(plan.toPlanString()) + .answer(planContext.getResult()) + // .sourcesAsText(planContext.getVariables().get("sources")) + .sourcesAsText("sources placeholders") + .question(question) + .build(); + } + + @Override + public void runStreaming( + String questionOrConversation, RAGOptions options, OutputStream outputStream) { + throw new IllegalStateException("Streaming not supported for this approach"); + } + + /** + * Build semantic kernel context with AnswerQuestion semantic function and + * InformationFinder.Search native function. AnswerQuestion is imported from + * src/main/resources/semantickernel/Plugins. InformationFinder.Search is implemented in a + * traditional Java class method: CognitiveSearchPlugin.search + * + * @param options + * @return + */ + private Kernel buildSemanticKernel(RAGOptions options) { + Kernel kernel = + SKBuilders.kernel() + .withDefaultAIService( + SKBuilders.chatCompletion() + .withModelId(gptChatDeploymentModelId) + .withOpenAIClient(this.openAIAsyncClient) + .build()) + .build(); + + kernel.importSkill( + new CognitiveSearchPlugin(this.cognitiveSearchProxy, this.openAIProxy, options), + "InformationFinder"); + + kernel.importSkillFromResources("semantickernel/Plugins", "RAG", "AnswerQuestion", null); + + return kernel; + } +} diff --git a/app/backend/src/main/java/com/microsoft/openai/samples/rag/ask/approaches/semantickernel/JavaSemanticKernelWithMemoryApproach.java b/app/backend/src/main/java/com/microsoft/openai/samples/rag/ask/approaches/semantickernel/JavaSemanticKernelWithMemoryApproach.java index 932393b..099efb3 100644 --- a/app/backend/src/main/java/com/microsoft/openai/samples/rag/ask/approaches/semantickernel/JavaSemanticKernelWithMemoryApproach.java +++ b/app/backend/src/main/java/com/microsoft/openai/samples/rag/ask/approaches/semantickernel/JavaSemanticKernelWithMemoryApproach.java @@ -1,173 +1,195 @@ -package com.microsoft.openai.samples.rag.ask.approaches.semantickernel; - -import com.azure.ai.openai.OpenAIAsyncClient; -import com.azure.core.credential.TokenCredential; -import com.azure.search.documents.SearchAsyncClient; -import com.azure.search.documents.SearchDocument; -import com.microsoft.openai.samples.rag.approaches.ContentSource; -import com.microsoft.openai.samples.rag.approaches.RAGApproach; -import com.microsoft.openai.samples.rag.approaches.RAGOptions; -import com.microsoft.openai.samples.rag.approaches.RAGResponse; -import com.microsoft.openai.samples.rag.ask.approaches.semantickernel.memory.CustomAzureCognitiveSearchMemoryStore; -import com.microsoft.semantickernel.Kernel; -import com.microsoft.semantickernel.SKBuilders; -import com.microsoft.semantickernel.ai.embeddings.Embedding; -import com.microsoft.semantickernel.memory.MemoryQueryResult; -import com.microsoft.semantickernel.memory.MemoryRecord; -import com.microsoft.semantickernel.orchestration.SKContext; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; -import org.springframework.beans.factory.annotation.Value; -import org.springframework.stereotype.Component; -import reactor.core.publisher.Flux; -import reactor.core.publisher.Mono; - -import java.io.OutputStream; -import java.util.List; -import java.util.function.Function; -import java.util.stream.Collectors; - -/** - Use Java Semantic Kernel framework with built-in MemoryStore for embeddings similarity search. - A semantic function is defined in RAG.AnswerQuestion (src/main/resources/semantickernel/Plugins) to build the prompt template which is grounded using results from the Memory Store. - A customized version of SK built-in CognitiveSearchMemoryStore is used to map index fields populated by the documents ingestion process. - */ -@Component -public class JavaSemanticKernelWithMemoryApproach implements RAGApproach { - private static final Logger LOGGER = LoggerFactory.getLogger(JavaSemanticKernelWithMemoryApproach.class); - private final TokenCredential tokenCredential; - private final OpenAIAsyncClient openAIAsyncClient; - - private final SearchAsyncClient searchAsyncClient; - - private final String EMBEDDING_FIELD_NAME = "embedding"; - - @Value("${cognitive.search.service}") - String searchServiceName; - @Value("${cognitive.search.index}") - String indexName; - @Value("${openai.chatgpt.deployment}") - private String gptChatDeploymentModelId; - - @Value("${openai.embedding.deployment}") - private String embeddingDeploymentModelId; - - public JavaSemanticKernelWithMemoryApproach(TokenCredential tokenCredential, OpenAIAsyncClient openAIAsyncClient, SearchAsyncClient searchAsyncClient) { - this.tokenCredential = tokenCredential; - this.openAIAsyncClient = openAIAsyncClient; - this.searchAsyncClient = searchAsyncClient; - } - - /** - * @param question - * @param options - * @return - */ - @Override - public RAGResponse run(String question, RAGOptions options) { - - //Build semantic kernel context with Azure Cognitive Search as memory store. AnswerQuestion skill is imported from src/main/resources/semantickernel/Plugins. - Kernel semanticKernel = buildSemanticKernel(options); - - /** - * STEP 1: Retrieve relevant documents using user question - * Use semantic kernel built-in memory.searchAsync. It uses OpenAI to generate embeddings for the provided question. - * Question embeddings are provided to cognitive search via search options. - */ - List memoryResult = semanticKernel.getMemory().searchAsync( - indexName, - question, - options.getTop(), - 0.5f, - false) - .block(); - - LOGGER.info("Total {} sources found in cognitive vector store for search query[{}]", memoryResult.size(), question); - - String sources = buildSourcesText(memoryResult); - List sourcesList = buildSources(memoryResult); - - //STEP 2: Build a SK context with the sources retrieved from the memory store and the user question. - SKContext skcontext = SKBuilders.context().build() - .setVariable("sources", sources) - .setVariable("input", question); - - //STEP 3: Get a reference of the semantic function [AnswerQuestion] of the [RAG] plugin (a.k.a. skill) from the SK skills registry and provide it with the pre-built context. - Mono result = semanticKernel.getFunction("RAG", "AnswerQuestion").invokeAsync(skcontext); - - return new RAGResponse.Builder() - //.prompt(plan.toPlanString()) - .prompt("Prompt is managed by SK and can't be displayed here. See App logs for prompt") - //STEP 4: triggering Open AI to get an answer - .answer(result.block().getResult()) - .sources(sourcesList) - .sourcesAsText(sources) - .question(question) - .build(); - - } - - @Override - public void runStreaming(String questionOrConversation, RAGOptions options, OutputStream outputStream) { - throw new IllegalStateException("Streaming not supported for this approach"); - } - - private List buildSources(List memoryResult) { - return memoryResult - .stream() - .map(result -> { - return new ContentSource( - result.getMetadata().getId(), - result.getMetadata().getText() - ); - }) - .collect(Collectors.toList()); - } - - private String buildSourcesText(List memoryResult) { - StringBuilder sourcesContentBuffer = new StringBuilder(); - memoryResult.stream().forEach(memory -> { - sourcesContentBuffer.append(memory.getMetadata().getId()) - .append(": ") - .append(memory.getMetadata().getText().replace("\n", "")) - .append("\n"); - }); - return sourcesContentBuffer.toString(); - } - - private Kernel buildSemanticKernel(RAGOptions options) { - var kernelWithACS = SKBuilders.kernel() - .withMemoryStorage( - new CustomAzureCognitiveSearchMemoryStore("https://%s.search.windows.net".formatted(searchServiceName), - tokenCredential, - this.searchAsyncClient, - this.EMBEDDING_FIELD_NAME, - buildCustomMemoryMapper())) - .withDefaultAIService(SKBuilders.textEmbeddingGeneration() - .withOpenAIClient(openAIAsyncClient) - .withModelId(embeddingDeploymentModelId) - .build()) - .withDefaultAIService(SKBuilders.chatCompletion() - .withModelId(gptChatDeploymentModelId) - .withOpenAIClient(this.openAIAsyncClient) - .build()) - .build(); - - kernelWithACS.importSkillFromResources("semantickernel/Plugins", "RAG", "AnswerQuestion", null); - return kernelWithACS; - } - - private Function buildCustomMemoryMapper() { - return searchDocument -> { - return MemoryRecord.localRecord( - (String) searchDocument.get("sourcepage"), - (String) searchDocument.get("content"), - "chunked text from original source", - new Embedding((List) searchDocument.get(EMBEDDING_FIELD_NAME)), - (String) searchDocument.get("category"), - (String) searchDocument.get("id"), - null); - - }; - } -} +// Copyright (c) Microsoft. All rights reserved. +package com.microsoft.openai.samples.rag.ask.approaches.semantickernel; + +import com.azure.ai.openai.OpenAIAsyncClient; +import com.azure.core.credential.TokenCredential; +import com.azure.search.documents.SearchAsyncClient; +import com.azure.search.documents.SearchDocument; +import com.microsoft.openai.samples.rag.approaches.ContentSource; +import com.microsoft.openai.samples.rag.approaches.RAGApproach; +import com.microsoft.openai.samples.rag.approaches.RAGOptions; +import com.microsoft.openai.samples.rag.approaches.RAGResponse; +import com.microsoft.openai.samples.rag.ask.approaches.semantickernel.memory.CustomAzureCognitiveSearchMemoryStore; +import com.microsoft.semantickernel.Kernel; +import com.microsoft.semantickernel.SKBuilders; +import com.microsoft.semantickernel.ai.embeddings.Embedding; +import com.microsoft.semantickernel.memory.MemoryQueryResult; +import com.microsoft.semantickernel.memory.MemoryRecord; +import com.microsoft.semantickernel.orchestration.SKContext; +import java.io.OutputStream; +import java.util.List; +import java.util.function.Function; +import java.util.stream.Collectors; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.beans.factory.annotation.Value; +import org.springframework.stereotype.Component; +import reactor.core.publisher.Mono; + +/** + * Use Java Semantic Kernel framework with built-in MemoryStore for embeddings similarity search. A + * semantic function is defined in RAG.AnswerQuestion (src/main/resources/semantickernel/Plugins) to + * build the prompt template which is grounded using results from the Memory Store. A customized + * version of SK built-in CognitiveSearchMemoryStore is used to map index fields populated by the + * documents ingestion process. + */ +@Component +public class JavaSemanticKernelWithMemoryApproach implements RAGApproach { + private static final Logger LOGGER = + LoggerFactory.getLogger(JavaSemanticKernelWithMemoryApproach.class); + private final TokenCredential tokenCredential; + private final OpenAIAsyncClient openAIAsyncClient; + + private final SearchAsyncClient searchAsyncClient; + + private final String EMBEDDING_FIELD_NAME = "embedding"; + + @Value("${cognitive.search.service}") + String searchServiceName; + + @Value("${cognitive.search.index}") + String indexName; + + @Value("${openai.chatgpt.deployment}") + private String gptChatDeploymentModelId; + + @Value("${openai.embedding.deployment}") + private String embeddingDeploymentModelId; + + public JavaSemanticKernelWithMemoryApproach( + TokenCredential tokenCredential, + OpenAIAsyncClient openAIAsyncClient, + SearchAsyncClient searchAsyncClient) { + this.tokenCredential = tokenCredential; + this.openAIAsyncClient = openAIAsyncClient; + this.searchAsyncClient = searchAsyncClient; + } + + /** + * @param question + * @param options + * @return + */ + @Override + public RAGResponse run(String question, RAGOptions options) { + + // Build semantic kernel context with Azure Cognitive Search as memory store. AnswerQuestion + // skill is imported from src/main/resources/semantickernel/Plugins. + Kernel semanticKernel = buildSemanticKernel(options); + + /** + * STEP 1: Retrieve relevant documents using user question Use semantic kernel built-in + * memory.searchAsync. It uses OpenAI to generate embeddings for the provided question. + * Question embeddings are provided to cognitive search via search options. + */ + List memoryResult = + semanticKernel + .getMemory() + .searchAsync(indexName, question, options.getTop(), 0.5f, false) + .block(); + + LOGGER.info( + "Total {} sources found in cognitive vector store for search query[{}]", + memoryResult.size(), + question); + + String sources = buildSourcesText(memoryResult); + List sourcesList = buildSources(memoryResult); + + // STEP 2: Build a SK context with the sources retrieved from the memory store and the user + // question. + SKContext skcontext = + SKBuilders.context() + .build() + .setVariable("sources", sources) + .setVariable("input", question); + + // STEP 3: Get a reference of the semantic function [AnswerQuestion] of the [RAG] plugin + // (a.k.a. skill) from the SK skills registry and provide it with the pre-built context. + Mono result = + semanticKernel.getFunction("RAG", "AnswerQuestion").invokeAsync(skcontext); + + return new RAGResponse.Builder() + // .prompt(plan.toPlanString()) + .prompt( + "Prompt is managed by SK and can't be displayed here. See App logs for" + + " prompt") + // STEP 4: triggering Open AI to get an answer + .answer(result.block().getResult()) + .sources(sourcesList) + .sourcesAsText(sources) + .question(question) + .build(); + } + + @Override + public void runStreaming( + String questionOrConversation, RAGOptions options, OutputStream outputStream) { + throw new IllegalStateException("Streaming not supported for this approach"); + } + + private List buildSources(List memoryResult) { + return memoryResult.stream() + .map( + result -> { + return new ContentSource( + result.getMetadata().getId(), result.getMetadata().getText()); + }) + .collect(Collectors.toList()); + } + + private String buildSourcesText(List memoryResult) { + StringBuilder sourcesContentBuffer = new StringBuilder(); + memoryResult.stream() + .forEach( + memory -> { + sourcesContentBuffer + .append(memory.getMetadata().getId()) + .append(": ") + .append(memory.getMetadata().getText().replace("\n", "")) + .append("\n"); + }); + return sourcesContentBuffer.toString(); + } + + private Kernel buildSemanticKernel(RAGOptions options) { + var kernelWithACS = + SKBuilders.kernel() + .withMemoryStorage( + new CustomAzureCognitiveSearchMemoryStore( + "https://%s.search.windows.net" + .formatted(searchServiceName), + tokenCredential, + this.searchAsyncClient, + this.EMBEDDING_FIELD_NAME, + buildCustomMemoryMapper())) + .withDefaultAIService( + SKBuilders.textEmbeddingGeneration() + .withOpenAIClient(openAIAsyncClient) + .withModelId(embeddingDeploymentModelId) + .build()) + .withDefaultAIService( + SKBuilders.chatCompletion() + .withModelId(gptChatDeploymentModelId) + .withOpenAIClient(this.openAIAsyncClient) + .build()) + .build(); + + kernelWithACS.importSkillFromResources( + "semantickernel/Plugins", "RAG", "AnswerQuestion", null); + return kernelWithACS; + } + + private Function buildCustomMemoryMapper() { + return searchDocument -> { + return MemoryRecord.localRecord( + (String) searchDocument.get("sourcepage"), + (String) searchDocument.get("content"), + "chunked text from original source", + new Embedding((List) searchDocument.get(EMBEDDING_FIELD_NAME)), + (String) searchDocument.get("category"), + (String) searchDocument.get("id"), + null); + }; + } +} diff --git a/app/backend/src/main/java/com/microsoft/openai/samples/rag/ask/approaches/semantickernel/memory/CustomAzureCognitiveSearchMemoryStore.java b/app/backend/src/main/java/com/microsoft/openai/samples/rag/ask/approaches/semantickernel/memory/CustomAzureCognitiveSearchMemoryStore.java index f33a6ba..efe44e1 100644 --- a/app/backend/src/main/java/com/microsoft/openai/samples/rag/ask/approaches/semantickernel/memory/CustomAzureCognitiveSearchMemoryStore.java +++ b/app/backend/src/main/java/com/microsoft/openai/samples/rag/ask/approaches/semantickernel/memory/CustomAzureCognitiveSearchMemoryStore.java @@ -1,86 +1,94 @@ -package com.microsoft.openai.samples.rag.ask.approaches.semantickernel.memory; - -import com.azure.core.credential.TokenCredential; -import com.azure.search.documents.SearchAsyncClient; -import com.azure.search.documents.SearchClient; -import com.azure.search.documents.SearchDocument; -import com.azure.search.documents.indexes.SearchIndexClientBuilder; -import com.azure.search.documents.models.SearchOptions; -import com.azure.search.documents.models.SearchQueryVector; -import com.microsoft.semantickernel.ai.embeddings.Embedding; -import com.microsoft.semantickernel.connectors.memory.azurecognitivesearch.AzureCognitiveSearchMemoryRecord; -import com.microsoft.semantickernel.connectors.memory.azurecognitivesearch.AzureCognitiveSearchMemoryStore; -import com.microsoft.semantickernel.memory.MemoryRecord; -import reactor.core.publisher.Mono; -import reactor.util.function.Tuple2; -import reactor.util.function.Tuples; - -import javax.annotation.Nonnull; -import java.util.Collection; -import java.util.function.Function; -import java.util.stream.Collectors; - -public class CustomAzureCognitiveSearchMemoryStore extends AzureCognitiveSearchMemoryStore { - - private SearchAsyncClient searchClient; - private String embeddingFieldMapping = "Embedding"; - - private Function memoryRecordMapper ; - /** - * Create a new instance of custom memory storage using Azure Cognitive Search. - * - * @param endpoint Azure Cognitive Search URI, e.g. "https://contoso.search.windows.net" - * @param credentials Azure service credentials - * @param searchClient Another instance of cognitive search client. Unfortunately this is a hack as - * current getSearchClient is private in parent class. - */ - public CustomAzureCognitiveSearchMemoryStore( - @Nonnull String endpoint, @Nonnull TokenCredential credentials, @Nonnull SearchAsyncClient searchClient, String embeddingFieldMapping) { - super(endpoint, credentials); - this.searchClient = searchClient; - if(embeddingFieldMapping != null && !embeddingFieldMapping.isEmpty()) - this.embeddingFieldMapping = embeddingFieldMapping ; - - } - - public CustomAzureCognitiveSearchMemoryStore( - @Nonnull String endpoint, @Nonnull TokenCredential credentials, @Nonnull SearchAsyncClient searchClient, String embeddingFieldMapping, Function memoryRecordMapper) { - this(endpoint,credentials,searchClient,embeddingFieldMapping); - this.memoryRecordMapper = memoryRecordMapper ; - } - - public Mono>> getNearestMatchesAsync( - @Nonnull String collectionName, - @Nonnull Embedding embedding, - int limit, - float minRelevanceScore, - boolean withEmbedding) { - - SearchQueryVector searchVector = - new SearchQueryVector() - .setKNearestNeighborsCount(limit) - .setFields(embeddingFieldMapping) - .setValue(embedding.getVector()); - - SearchOptions searchOptions = new SearchOptions().setVectors(searchVector); - - return searchClient.search(null, searchOptions) - .filter(result -> (double) minRelevanceScore <= result.getScore()) - .map( - result -> { - MemoryRecord memoryRecord; - //Use default SK mapper if no custom mapper is provided - if(this.memoryRecordMapper == null) { - memoryRecord = result.getDocument(AzureCognitiveSearchMemoryRecord.class) - .toMemoryRecord(withEmbedding); - } else { - memoryRecord = this.memoryRecordMapper.apply(result.getDocument(SearchDocument.class)); - } - - float score = (float) result.getScore(); - return Tuples.of(memoryRecord, score); - }) - .collect(Collectors.toList()); - } - -} +// Copyright (c) Microsoft. All rights reserved. +package com.microsoft.openai.samples.rag.ask.approaches.semantickernel.memory; + +import com.azure.core.credential.TokenCredential; +import com.azure.search.documents.SearchAsyncClient; +import com.azure.search.documents.SearchDocument; +import com.azure.search.documents.models.SearchOptions; +import com.azure.search.documents.models.SearchQueryVector; +import com.microsoft.semantickernel.ai.embeddings.Embedding; +import com.microsoft.semantickernel.connectors.memory.azurecognitivesearch.AzureCognitiveSearchMemoryRecord; +import com.microsoft.semantickernel.connectors.memory.azurecognitivesearch.AzureCognitiveSearchMemoryStore; +import com.microsoft.semantickernel.memory.MemoryRecord; +import java.util.Collection; +import java.util.function.Function; +import java.util.stream.Collectors; +import javax.annotation.Nonnull; +import reactor.core.publisher.Mono; +import reactor.util.function.Tuple2; +import reactor.util.function.Tuples; + +public class CustomAzureCognitiveSearchMemoryStore extends AzureCognitiveSearchMemoryStore { + + private SearchAsyncClient searchClient; + private String embeddingFieldMapping = "Embedding"; + + private Function memoryRecordMapper; + + /** + * Create a new instance of custom memory storage using Azure Cognitive Search. + * + * @param endpoint Azure Cognitive Search URI, e.g. "https://contoso.search.windows.net" + * @param credentials Azure service credentials + * @param searchClient Another instance of cognitive search client. Unfortunately this is a hack + * as current getSearchClient is private in parent class. + */ + public CustomAzureCognitiveSearchMemoryStore( + @Nonnull String endpoint, + @Nonnull TokenCredential credentials, + @Nonnull SearchAsyncClient searchClient, + String embeddingFieldMapping) { + super(endpoint, credentials); + this.searchClient = searchClient; + if (embeddingFieldMapping != null && !embeddingFieldMapping.isEmpty()) + this.embeddingFieldMapping = embeddingFieldMapping; + } + + public CustomAzureCognitiveSearchMemoryStore( + @Nonnull String endpoint, + @Nonnull TokenCredential credentials, + @Nonnull SearchAsyncClient searchClient, + String embeddingFieldMapping, + Function memoryRecordMapper) { + this(endpoint, credentials, searchClient, embeddingFieldMapping); + this.memoryRecordMapper = memoryRecordMapper; + } + + public Mono>> getNearestMatchesAsync( + @Nonnull String collectionName, + @Nonnull Embedding embedding, + int limit, + float minRelevanceScore, + boolean withEmbedding) { + + SearchQueryVector searchVector = + new SearchQueryVector() + .setKNearestNeighborsCount(limit) + .setFields(embeddingFieldMapping) + .setValue(embedding.getVector()); + + SearchOptions searchOptions = new SearchOptions().setVectors(searchVector); + + return searchClient + .search(null, searchOptions) + .filter(result -> (double) minRelevanceScore <= result.getScore()) + .map( + result -> { + MemoryRecord memoryRecord; + // Use default SK mapper if no custom mapper is provided + if (this.memoryRecordMapper == null) { + memoryRecord = + result.getDocument(AzureCognitiveSearchMemoryRecord.class) + .toMemoryRecord(withEmbedding); + } else { + memoryRecord = + this.memoryRecordMapper.apply( + result.getDocument(SearchDocument.class)); + } + + float score = (float) result.getScore(); + return Tuples.of(memoryRecord, score); + }) + .collect(Collectors.toList()); + } +} diff --git a/app/backend/src/main/java/com/microsoft/openai/samples/rag/ask/controller/AskController.java b/app/backend/src/main/java/com/microsoft/openai/samples/rag/ask/controller/AskController.java index c2ae564..01c0167 100644 --- a/app/backend/src/main/java/com/microsoft/openai/samples/rag/ask/controller/AskController.java +++ b/app/backend/src/main/java/com/microsoft/openai/samples/rag/ask/controller/AskController.java @@ -1,3 +1,4 @@ +// Copyright (c) Microsoft. All rights reserved. package com.microsoft.openai.samples.rag.ask.controller; import com.microsoft.openai.samples.rag.approaches.RAGApproach; @@ -29,20 +30,23 @@ public class AskController { this.ragApproachFactory = ragApproachFactory; } - @PostMapping( - value = "/api/ask", - produces = MediaType.APPLICATION_NDJSON_VALUE - ) - public ResponseEntity openAIAskStream( - @RequestBody ChatAppRequest askRequest - ) { + @PostMapping(value = "/api/ask", produces = MediaType.APPLICATION_NDJSON_VALUE) + public ResponseEntity openAIAskStream(@RequestBody ChatAppRequest askRequest) { if (!askRequest.stream()) { - LOGGER.warn("Requested a content-type of application/ndjson however did not requested streaming. Please use a content-type of application/json"); - throw new ResponseStatusException(HttpStatus.BAD_REQUEST, "Requested a content-type of application/ndjson however did not requested streaming. Please use a content-type of application/json"); + LOGGER.warn( + "Requested a content-type of application/ndjson however did not requested" + + " streaming. Please use a content-type of application/json"); + throw new ResponseStatusException( + HttpStatus.BAD_REQUEST, + "Requested a content-type of application/ndjson however did not requested" + + " streaming. Please use a content-type of application/json"); } String question = askRequest.messages().get(askRequest.messages().size() - 1).content(); - LOGGER.info("Received request for ask api with question [{}] and approach[{}]", question, askRequest.approach()); + LOGGER.info( + "Received request for ask api with question [{}] and approach[{}]", + question, + askRequest.approach()); if (!StringUtils.hasText(askRequest.approach())) { LOGGER.warn("approach cannot be null in ASK request"); @@ -54,41 +58,50 @@ public ResponseEntity openAIAskStream( return ResponseEntity.status(HttpStatus.BAD_REQUEST).body(null); } - var ragOptions = new RAGOptions.Builder() - .retrievialMode(askRequest.context().overrides().retrieval_mode().name()) - .semanticKernelMode(askRequest.context().overrides().semantic_kernel_mode()) - .semanticRanker(askRequest.context().overrides().semantic_ranker()) - .semanticCaptions(askRequest.context().overrides().semantic_captions()) - .excludeCategory(askRequest.context().overrides().exclude_category()) - .promptTemplate(askRequest.context().overrides().prompt_template()) - .top(askRequest.context().overrides().top()) - .build(); - - RAGApproach ragApproach = ragApproachFactory.createApproach(askRequest.approach(), RAGType.ASK, ragOptions); - - StreamingResponseBody response = output -> { - try { - ragApproach.runStreaming(question, ragOptions, output); - } finally { - output.flush(); - output.close(); - } - }; - - return ResponseEntity.ok() - .contentType(MediaType.APPLICATION_NDJSON) - .body(response); + var ragOptions = + new RAGOptions.Builder() + .retrievialMode(askRequest.context().overrides().retrieval_mode().name()) + .semanticKernelMode(askRequest.context().overrides().semantic_kernel_mode()) + .semanticRanker(askRequest.context().overrides().semantic_ranker()) + .semanticCaptions(askRequest.context().overrides().semantic_captions()) + .excludeCategory(askRequest.context().overrides().exclude_category()) + .promptTemplate(askRequest.context().overrides().prompt_template()) + .top(askRequest.context().overrides().top()) + .build(); + + RAGApproach ragApproach = + ragApproachFactory.createApproach(askRequest.approach(), RAGType.ASK, ragOptions); + + StreamingResponseBody response = + output -> { + try { + ragApproach.runStreaming(question, ragOptions, output); + } finally { + output.flush(); + output.close(); + } + }; + + return ResponseEntity.ok().contentType(MediaType.APPLICATION_NDJSON).body(response); } @PostMapping("/api/ask") public ResponseEntity openAIAsk(@RequestBody ChatAppRequest askRequest) { if (askRequest.stream()) { - LOGGER.warn("Requested a content-type of application/json however also requested streaming. Please use a content-type of application/ndjson"); - throw new ResponseStatusException(HttpStatus.BAD_REQUEST, "Requested a content-type of application/json however also requested streaming. Please use a content-type of application/ndjson"); + LOGGER.warn( + "Requested a content-type of application/json however also requested streaming." + + " Please use a content-type of application/ndjson"); + throw new ResponseStatusException( + HttpStatus.BAD_REQUEST, + "Requested a content-type of application/json however also requested streaming." + + " Please use a content-type of application/ndjson"); } String question = askRequest.messages().get(askRequest.messages().size() - 1).content(); - LOGGER.info("Received request for ask api with question [{}] and approach[{}]", question, askRequest.approach()); + LOGGER.info( + "Received request for ask api with question [{}] and approach[{}]", + question, + askRequest.approach()); if (!StringUtils.hasText(askRequest.approach())) { LOGGER.warn("approach cannot be null in ASK request"); @@ -100,19 +113,21 @@ public ResponseEntity openAIAsk(@RequestBody ChatAppRequest askReq return ResponseEntity.status(HttpStatus.BAD_REQUEST).body(null); } - var ragOptions = new RAGOptions.Builder() - .retrievialMode(askRequest.context().overrides().retrieval_mode().name()) - .semanticKernelMode(askRequest.context().overrides().semantic_kernel_mode()) - .semanticRanker(askRequest.context().overrides().semantic_ranker()) - .semanticCaptions(askRequest.context().overrides().semantic_captions()) - .excludeCategory(askRequest.context().overrides().exclude_category()) - .promptTemplate(askRequest.context().overrides().prompt_template()) - .top(askRequest.context().overrides().top()) - .build(); - - RAGApproach ragApproach = ragApproachFactory.createApproach(askRequest.approach(), RAGType.ASK, ragOptions); - - - return ResponseEntity.ok(ChatResponse.buildChatResponse(ragApproach.run(question, ragOptions))); + var ragOptions = + new RAGOptions.Builder() + .retrievialMode(askRequest.context().overrides().retrieval_mode().name()) + .semanticKernelMode(askRequest.context().overrides().semantic_kernel_mode()) + .semanticRanker(askRequest.context().overrides().semantic_ranker()) + .semanticCaptions(askRequest.context().overrides().semantic_captions()) + .excludeCategory(askRequest.context().overrides().exclude_category()) + .promptTemplate(askRequest.context().overrides().prompt_template()) + .top(askRequest.context().overrides().top()) + .build(); + + RAGApproach ragApproach = + ragApproachFactory.createApproach(askRequest.approach(), RAGType.ASK, ragOptions); + + return ResponseEntity.ok( + ChatResponse.buildChatResponse(ragApproach.run(question, ragOptions))); } -} \ No newline at end of file +} diff --git a/app/backend/src/main/java/com/microsoft/openai/samples/rag/chat/approaches/PlainJavaChatApproach.java b/app/backend/src/main/java/com/microsoft/openai/samples/rag/chat/approaches/PlainJavaChatApproach.java index f129efd..4a5231d 100644 --- a/app/backend/src/main/java/com/microsoft/openai/samples/rag/chat/approaches/PlainJavaChatApproach.java +++ b/app/backend/src/main/java/com/microsoft/openai/samples/rag/chat/approaches/PlainJavaChatApproach.java @@ -1,142 +1,171 @@ -package com.microsoft.openai.samples.rag.chat.approaches; - -import com.azure.ai.openai.models.ChatChoice; -import com.azure.ai.openai.models.ChatCompletions; -import com.azure.core.util.IterableStream; -import com.fasterxml.jackson.databind.ObjectMapper; -import com.microsoft.openai.samples.rag.approaches.ContentSource; -import com.microsoft.openai.samples.rag.approaches.RAGApproach; -import com.microsoft.openai.samples.rag.approaches.RAGOptions; -import com.microsoft.openai.samples.rag.approaches.RAGResponse; -import com.microsoft.openai.samples.rag.common.ChatGPTConversation; -import com.microsoft.openai.samples.rag.common.ChatGPTUtils; -import com.microsoft.openai.samples.rag.controller.ChatResponse; -import com.microsoft.openai.samples.rag.proxy.OpenAIProxy; -import com.microsoft.openai.samples.rag.retrieval.FactsRetrieverProvider; -import com.microsoft.openai.samples.rag.retrieval.Retriever; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; -import org.springframework.context.ApplicationContext; -import org.springframework.stereotype.Component; - -import java.io.IOException; -import java.io.OutputStream; -import java.util.List; - -/** - * Simple chat-read-retrieve-read java implementation, using the Cognitive Search and OpenAI APIs directly. - * It first calls OpenAI to generate a search keyword for the chat history and then answer to the last chat question. - * Several cognitive search retrieval options are available: Text, Vector, Hybrid. - * When Hybrid and Vector are selected an additional call to OpenAI is required to generate embeddings vector for the chat extracted keywords. - */ -@Component -public class PlainJavaChatApproach implements RAGApproach { - - private static final Logger LOGGER = LoggerFactory.getLogger(PlainJavaChatApproach.class); - private final ObjectMapper objectMapper; - private ApplicationContext applicationContext; - private final OpenAIProxy openAIProxy; - private final FactsRetrieverProvider factsRetrieverProvider; - - public PlainJavaChatApproach( - FactsRetrieverProvider factsRetrieverProvider, - OpenAIProxy openAIProxy, - ObjectMapper objectMapper) { - this.factsRetrieverProvider = factsRetrieverProvider; - this.openAIProxy = openAIProxy; - this.objectMapper = objectMapper; - } - - /** - * @param questionOrConversation - * @param options - * @return - */ - @Override - public RAGResponse run(ChatGPTConversation questionOrConversation, RAGOptions options) { - //Get instance of retriever based on the retrieval mode: hybryd, text, vectors. - Retriever factsRetriever = factsRetrieverProvider.getFactsRetriever(options); - - //STEP 1: Retrieve relevant documents using kewirds extracted from the chat history. An additional call to OpenAI is required to generate keywords. - List sources = factsRetriever.retrieveFromConversation(questionOrConversation, options); - LOGGER.info("Total {} sources retrieved", sources.size()); - - - //STEP 2: Build a grounded prompt using the retrieved documents. RAG options is used to configure additional prompt extension like 'suggesting follow up questions' option. - var semanticSearchChat = new SemanticSearchChat(questionOrConversation, sources, options.getPromptTemplate(), false, options.isSuggestFollowupQuestions()); - var chatCompletionsOptions = ChatGPTUtils.buildDefaultChatCompletionsOptions(semanticSearchChat.getMessages()); - - // STEP 3: Generate a contextual and content specific answer using the search results and chat history - ChatCompletions chatCompletions = openAIProxy.getChatCompletions(chatCompletionsOptions); - - LOGGER.info("Chat completion generated with Prompt Tokens[{}], Completions Tokens[{}], Total Tokens[{}]", - chatCompletions.getUsage().getPromptTokens(), - chatCompletions.getUsage().getCompletionTokens(), - chatCompletions.getUsage().getTotalTokens()); - - return new RAGResponse.Builder() - .question(ChatGPTUtils.getLastUserQuestion(questionOrConversation.getMessages())) - .prompt(ChatGPTUtils.formatAsChatML(semanticSearchChat.getMessages())) - .answer(chatCompletions.getChoices().get(0).getMessage().getContent()) - .sources(sources) - .build(); - } - - @Override - public void runStreaming( - ChatGPTConversation questionOrConversation, - RAGOptions options, - OutputStream outputStream) { - Retriever factsRetriever = factsRetrieverProvider.getFactsRetriever(options); - List sources = factsRetriever.retrieveFromConversation(questionOrConversation, options); - LOGGER.info("Total {} sources retrieved", sources.size()); - - // Replace whole prompt is not supported yet - var semanticSearchChat = new SemanticSearchChat(questionOrConversation, sources, options.getPromptTemplate(), false, options.isSuggestFollowupQuestions()); - var chatCompletionsOptions = ChatGPTUtils.buildDefaultChatCompletionsOptions(semanticSearchChat.getMessages()); - - int index = 0; - - IterableStream completions = openAIProxy.getChatCompletionsStream(chatCompletionsOptions); - - for (ChatCompletions completion : completions) { - if (completion.getUsage() != null) { - LOGGER.info("Chat completion generated with Prompt Tokens[{}], Completions Tokens[{}], Total Tokens[{}]", - completion.getUsage().getPromptTokens(), - completion.getUsage().getCompletionTokens(), - completion.getUsage().getTotalTokens()); - } - - List choices = completion.getChoices(); - - for (ChatChoice choice : choices) { - if (choice.getDelta().getContent() == null) { - continue; - } - - RAGResponse ragResponse = new RAGResponse.Builder() - .question(ChatGPTUtils.getLastUserQuestion(questionOrConversation.getMessages())) - .prompt(ChatGPTUtils.formatAsChatML(semanticSearchChat.getMessages())) - .answer(choice.getDelta().getContent()) - .sources(sources) - .build(); - - ChatResponse response; - if (index == 0) { - response = ChatResponse.buildChatResponse(ragResponse); - } else { - response = ChatResponse.buildChatDeltaResponse(index, ragResponse); - } - index++; - - try { - String value = objectMapper.writeValueAsString(response) + "\n"; - outputStream.write(value.getBytes()); - outputStream.flush(); - } catch (IOException e) { - throw new RuntimeException(e); - } - } - } - } -} +// Copyright (c) Microsoft. All rights reserved. +package com.microsoft.openai.samples.rag.chat.approaches; + +import com.azure.ai.openai.models.ChatChoice; +import com.azure.ai.openai.models.ChatCompletions; +import com.azure.core.util.IterableStream; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.microsoft.openai.samples.rag.approaches.ContentSource; +import com.microsoft.openai.samples.rag.approaches.RAGApproach; +import com.microsoft.openai.samples.rag.approaches.RAGOptions; +import com.microsoft.openai.samples.rag.approaches.RAGResponse; +import com.microsoft.openai.samples.rag.common.ChatGPTConversation; +import com.microsoft.openai.samples.rag.common.ChatGPTUtils; +import com.microsoft.openai.samples.rag.controller.ChatResponse; +import com.microsoft.openai.samples.rag.proxy.OpenAIProxy; +import com.microsoft.openai.samples.rag.retrieval.FactsRetrieverProvider; +import com.microsoft.openai.samples.rag.retrieval.Retriever; +import java.io.IOException; +import java.io.OutputStream; +import java.util.List; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.context.ApplicationContext; +import org.springframework.stereotype.Component; + +/** + * Simple chat-read-retrieve-read java implementation, using the Cognitive Search and OpenAI APIs + * directly. It first calls OpenAI to generate a search keyword for the chat history and then answer + * to the last chat question. Several cognitive search retrieval options are available: Text, + * Vector, Hybrid. When Hybrid and Vector are selected an additional call to OpenAI is required to + * generate embeddings vector for the chat extracted keywords. + */ +@Component +public class PlainJavaChatApproach implements RAGApproach { + + private static final Logger LOGGER = LoggerFactory.getLogger(PlainJavaChatApproach.class); + private final ObjectMapper objectMapper; + private ApplicationContext applicationContext; + private final OpenAIProxy openAIProxy; + private final FactsRetrieverProvider factsRetrieverProvider; + + public PlainJavaChatApproach( + FactsRetrieverProvider factsRetrieverProvider, + OpenAIProxy openAIProxy, + ObjectMapper objectMapper) { + this.factsRetrieverProvider = factsRetrieverProvider; + this.openAIProxy = openAIProxy; + this.objectMapper = objectMapper; + } + + /** + * @param questionOrConversation + * @param options + * @return + */ + @Override + public RAGResponse run(ChatGPTConversation questionOrConversation, RAGOptions options) { + // Get instance of retriever based on the retrieval mode: hybryd, text, vectors. + Retriever factsRetriever = factsRetrieverProvider.getFactsRetriever(options); + + // STEP 1: Retrieve relevant documents using kewirds extracted from the chat history. An + // additional call to OpenAI is required to generate keywords. + List sources = + factsRetriever.retrieveFromConversation(questionOrConversation, options); + LOGGER.info("Total {} sources retrieved", sources.size()); + + // STEP 2: Build a grounded prompt using the retrieved documents. RAG options is used to + // configure additional prompt extension like 'suggesting follow up questions' option. + var semanticSearchChat = + new SemanticSearchChat( + questionOrConversation, + sources, + options.getPromptTemplate(), + false, + options.isSuggestFollowupQuestions()); + var chatCompletionsOptions = + ChatGPTUtils.buildDefaultChatCompletionsOptions(semanticSearchChat.getMessages()); + + // STEP 3: Generate a contextual and content specific answer using the search results and + // chat history + ChatCompletions chatCompletions = openAIProxy.getChatCompletions(chatCompletionsOptions); + + LOGGER.info( + "Chat completion generated with Prompt Tokens[{}], Completions Tokens[{}], Total" + + " Tokens[{}]", + chatCompletions.getUsage().getPromptTokens(), + chatCompletions.getUsage().getCompletionTokens(), + chatCompletions.getUsage().getTotalTokens()); + + return new RAGResponse.Builder() + .question(ChatGPTUtils.getLastUserQuestion(questionOrConversation.getMessages())) + .prompt(ChatGPTUtils.formatAsChatML(semanticSearchChat.getMessages())) + .answer(chatCompletions.getChoices().get(0).getMessage().getContent()) + .sources(sources) + .build(); + } + + @Override + public void runStreaming( + ChatGPTConversation questionOrConversation, + RAGOptions options, + OutputStream outputStream) { + Retriever factsRetriever = factsRetrieverProvider.getFactsRetriever(options); + List sources = + factsRetriever.retrieveFromConversation(questionOrConversation, options); + LOGGER.info("Total {} sources retrieved", sources.size()); + + // Replace whole prompt is not supported yet + var semanticSearchChat = + new SemanticSearchChat( + questionOrConversation, + sources, + options.getPromptTemplate(), + false, + options.isSuggestFollowupQuestions()); + var chatCompletionsOptions = + ChatGPTUtils.buildDefaultChatCompletionsOptions(semanticSearchChat.getMessages()); + + int index = 0; + + IterableStream completions = + openAIProxy.getChatCompletionsStream(chatCompletionsOptions); + + for (ChatCompletions completion : completions) { + if (completion.getUsage() != null) { + LOGGER.info( + "Chat completion generated with Prompt Tokens[{}], Completions Tokens[{}]," + + " Total Tokens[{}]", + completion.getUsage().getPromptTokens(), + completion.getUsage().getCompletionTokens(), + completion.getUsage().getTotalTokens()); + } + + List choices = completion.getChoices(); + + for (ChatChoice choice : choices) { + if (choice.getDelta().getContent() == null) { + continue; + } + + RAGResponse ragResponse = + new RAGResponse.Builder() + .question( + ChatGPTUtils.getLastUserQuestion( + questionOrConversation.getMessages())) + .prompt( + ChatGPTUtils.formatAsChatML( + semanticSearchChat.getMessages())) + .answer(choice.getDelta().getContent()) + .sources(sources) + .build(); + + ChatResponse response; + if (index == 0) { + response = ChatResponse.buildChatResponse(ragResponse); + } else { + response = ChatResponse.buildChatDeltaResponse(index, ragResponse); + } + index++; + + try { + String value = objectMapper.writeValueAsString(response) + "\n"; + outputStream.write(value.getBytes()); + outputStream.flush(); + } catch (IOException e) { + throw new RuntimeException(e); + } + } + } + } +} diff --git a/app/backend/src/main/java/com/microsoft/openai/samples/rag/chat/approaches/SemanticSearchChat.java b/app/backend/src/main/java/com/microsoft/openai/samples/rag/chat/approaches/SemanticSearchChat.java index 8047798..257fcb6 100644 --- a/app/backend/src/main/java/com/microsoft/openai/samples/rag/chat/approaches/SemanticSearchChat.java +++ b/app/backend/src/main/java/com/microsoft/openai/samples/rag/chat/approaches/SemanticSearchChat.java @@ -1,114 +1,138 @@ -package com.microsoft.openai.samples.rag.chat.approaches; - -import com.azure.ai.openai.models.ChatMessage; -import com.azure.ai.openai.models.ChatRole; -import com.microsoft.openai.samples.rag.approaches.ContentSource; -import com.microsoft.openai.samples.rag.common.ChatGPTConversation; -import com.microsoft.openai.samples.rag.common.ChatGPTMessage; - -import java.util.ArrayList; -import java.util.List; - -public class SemanticSearchChat { - - private final List conversationHistory = new ArrayList<>(); - private final StringBuilder sources = new StringBuilder(); - private final Boolean followUpQuestions; - private final String customPrompt; - private final String systemMessage; - private Boolean replacePrompt = false; - - private static final String FOLLOW_UP_QUESTIONS_TEMPLATE = """ - After answering question, also generate three very brief follow-up questions that the user would likely ask next. - Use double angle brackets to reference the questions, e.g. <>. - Try not to repeat questions that have already been asked. - Only generate questions and do not generate any text before or after the questions, such as 'Next Questions' - """; - private static final String SYSTEM_CHAT_MESSAGE_TEMPLATE = """ - Assistant helps the company employees with their healthcare plan questions, and questions about the employee handbook. Be brief in your answers. - Answer ONLY with the facts listed in the list of sources below. If there isn't enough information below, say you don't know. Do not generate answers that don't use the sources below. If asking a clarifying question to the user would help, ask the question. - For tabular information return it as an html table. Do not return markdown format. - Each source has a name followed by colon and the actual information, always include the source name for each fact you use in the response. Use square brackets to reference the source, e.g. [info1.txt]. Don't combine sources, list each source separately, e.g. [info1.txt][info2.pdf]. - - %s - %s - Sources: - %s - """ ; - - /** - * - * @param conversation conversation history - * @param sources domain specific sources to be used in the prompt - * @param customPrompt custom prompt to be injected in the existing promptTemplate or used to replace it - * @param replacePrompt if true, the customPrompt will replace the default promptTemplate, otherwise it will be appended - * to the default promptTemplate in the predefined section - */ - public SemanticSearchChat(ChatGPTConversation conversation, List sources, String customPrompt, Boolean replacePrompt, Boolean followUpQuestions) { - if (conversation == null || conversation.getMessages().isEmpty()) - throw new IllegalStateException("conversation cannot be null or empty"); - if (sources == null) - throw new IllegalStateException("sources cannot be null"); - - if(replacePrompt) - throw new IllegalStateException("replace prompt is not supported yet. please set it to false when custom prompt is provided"); - - if(replacePrompt && (customPrompt == null || customPrompt.isEmpty())) - throw new IllegalStateException("customPrompt cannot be null or empty when replacePrompt is true"); - - this.followUpQuestions = followUpQuestions; - this.replacePrompt = replacePrompt; - this.customPrompt = customPrompt == null ? "" : customPrompt; - - // Build sources section - sources.iterator().forEachRemaining(source -> this.sources.append(source.getSourceName()).append(": ").append(source.getSourceContent()).append("\n")); - - this.systemMessage = SYSTEM_CHAT_MESSAGE_TEMPLATE.formatted(this.followUpQuestions ? FOLLOW_UP_QUESTIONS_TEMPLATE : "",this.customPrompt,this.sources.toString()); - - //Add system message - ChatMessage chatMessage = new ChatMessage(ChatRole.SYSTEM); - chatMessage.setContent(systemMessage); - this.conversationHistory.add(chatMessage); - - buildConversationHistory(conversation); - } - - /** - * - * @param conversation conversation history - * @param sources domain specific sources to be used in the prompt - */ - public SemanticSearchChat(ChatGPTConversation conversation, List sources) { - this(conversation, sources, null, false,false); - } - - /** - * - * @param conversation conversation history - * @param sources domain specific sources to be used in the prompt - * @param followupQuestions if true, the followup questions prompt will be injected in the promptTemplate - */ - public SemanticSearchChat(ChatGPTConversation conversation, List sources, Boolean followupQuestions) { - this(conversation, sources, null, false,followupQuestions); - } - - public List getMessages() { - return this.conversationHistory; - } - - private void buildConversationHistory(ChatGPTConversation conversation) { - // Build conversation history is the rest of the messages - conversation.getMessages().forEach(message -> { - if(message.role() == ChatGPTMessage.ChatRole.USER){ - ChatMessage chatMessage = new ChatMessage(ChatRole.USER); - chatMessage.setContent(message.content()); - this.conversationHistory.add(chatMessage); - } else if(message.role() == ChatGPTMessage.ChatRole.ASSISTANT) { - ChatMessage chatMessage = new ChatMessage(ChatRole.ASSISTANT); - chatMessage.setContent(message.content()); - this.conversationHistory.add(chatMessage); - } - }); - } - -} +// Copyright (c) Microsoft. All rights reserved. +package com.microsoft.openai.samples.rag.chat.approaches; + +import com.azure.ai.openai.models.ChatMessage; +import com.azure.ai.openai.models.ChatRole; +import com.microsoft.openai.samples.rag.approaches.ContentSource; +import com.microsoft.openai.samples.rag.common.ChatGPTConversation; +import com.microsoft.openai.samples.rag.common.ChatGPTMessage; +import java.util.ArrayList; +import java.util.List; + +public class SemanticSearchChat { + + private final List conversationHistory = new ArrayList<>(); + private final StringBuilder sources = new StringBuilder(); + private final Boolean followUpQuestions; + private final String customPrompt; + private final String systemMessage; + private Boolean replacePrompt = false; + + private static final String FOLLOW_UP_QUESTIONS_TEMPLATE = + """ + After answering question, also generate three very brief follow-up questions that the user would likely ask next. + Use double angle brackets to reference the questions, e.g. <>. + Try not to repeat questions that have already been asked. + Only generate questions and do not generate any text before or after the questions, such as 'Next Questions' + """; + private static final String SYSTEM_CHAT_MESSAGE_TEMPLATE = + """ + Assistant helps the company employees with their healthcare plan questions, and questions about the employee handbook. Be brief in your answers. + Answer ONLY with the facts listed in the list of sources below. If there isn't enough information below, say you don't know. Do not generate answers that don't use the sources below. If asking a clarifying question to the user would help, ask the question. + For tabular information return it as an html table. Do not return markdown format. + Each source has a name followed by colon and the actual information, always include the source name for each fact you use in the response. Use square brackets to reference the source, e.g. [info1.txt]. Don't combine sources, list each source separately, e.g. [info1.txt][info2.pdf]. + + %s + %s + Sources: + %s + """; + + /** + * @param conversation conversation history + * @param sources domain specific sources to be used in the prompt + * @param customPrompt custom prompt to be injected in the existing promptTemplate or used to + * replace it + * @param replacePrompt if true, the customPrompt will replace the default promptTemplate, + * otherwise it will be appended to the default promptTemplate in the predefined section + */ + public SemanticSearchChat( + ChatGPTConversation conversation, + List sources, + String customPrompt, + Boolean replacePrompt, + Boolean followUpQuestions) { + if (conversation == null || conversation.getMessages().isEmpty()) + throw new IllegalStateException("conversation cannot be null or empty"); + if (sources == null) throw new IllegalStateException("sources cannot be null"); + + if (replacePrompt) + throw new IllegalStateException( + "replace prompt is not supported yet. please set it to false when custom prompt" + + " is provided"); + + if (replacePrompt && (customPrompt == null || customPrompt.isEmpty())) + throw new IllegalStateException( + "customPrompt cannot be null or empty when replacePrompt is true"); + + this.followUpQuestions = followUpQuestions; + this.replacePrompt = replacePrompt; + this.customPrompt = customPrompt == null ? "" : customPrompt; + + // Build sources section + sources.iterator() + .forEachRemaining( + source -> + this.sources + .append(source.getSourceName()) + .append(": ") + .append(source.getSourceContent()) + .append("\n")); + + this.systemMessage = + SYSTEM_CHAT_MESSAGE_TEMPLATE.formatted( + this.followUpQuestions ? FOLLOW_UP_QUESTIONS_TEMPLATE : "", + this.customPrompt, + this.sources.toString()); + + // Add system message + ChatMessage chatMessage = new ChatMessage(ChatRole.SYSTEM); + chatMessage.setContent(systemMessage); + this.conversationHistory.add(chatMessage); + + buildConversationHistory(conversation); + } + + /** + * @param conversation conversation history + * @param sources domain specific sources to be used in the prompt + */ + public SemanticSearchChat(ChatGPTConversation conversation, List sources) { + this(conversation, sources, null, false, false); + } + + /** + * @param conversation conversation history + * @param sources domain specific sources to be used in the prompt + * @param followupQuestions if true, the followup questions prompt will be injected in the + * promptTemplate + */ + public SemanticSearchChat( + ChatGPTConversation conversation, + List sources, + Boolean followupQuestions) { + this(conversation, sources, null, false, followupQuestions); + } + + public List getMessages() { + return this.conversationHistory; + } + + private void buildConversationHistory(ChatGPTConversation conversation) { + // Build conversation history is the rest of the messages + conversation + .getMessages() + .forEach( + message -> { + if (message.role() == ChatGPTMessage.ChatRole.USER) { + ChatMessage chatMessage = new ChatMessage(ChatRole.USER); + chatMessage.setContent(message.content()); + this.conversationHistory.add(chatMessage); + } else if (message.role() == ChatGPTMessage.ChatRole.ASSISTANT) { + ChatMessage chatMessage = new ChatMessage(ChatRole.ASSISTANT); + chatMessage.setContent(message.content()); + this.conversationHistory.add(chatMessage); + } + }); + } +} diff --git a/app/backend/src/main/java/com/microsoft/openai/samples/rag/chat/controller/ChatController.java b/app/backend/src/main/java/com/microsoft/openai/samples/rag/chat/controller/ChatController.java index 22e029e..a3454db 100644 --- a/app/backend/src/main/java/com/microsoft/openai/samples/rag/chat/controller/ChatController.java +++ b/app/backend/src/main/java/com/microsoft/openai/samples/rag/chat/controller/ChatController.java @@ -1,3 +1,4 @@ +// Copyright (c) Microsoft. All rights reserved. package com.microsoft.openai.samples.rag.chat.controller; import com.microsoft.openai.samples.rag.approaches.RAGApproach; @@ -10,6 +11,9 @@ import com.microsoft.openai.samples.rag.controller.ChatAppRequest; import com.microsoft.openai.samples.rag.controller.ChatResponse; import com.microsoft.openai.samples.rag.controller.ResponseMessage; +import java.util.ArrayList; +import java.util.Collection; +import java.util.List; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.springframework.http.HttpStatus; @@ -22,10 +26,6 @@ import org.springframework.web.server.ResponseStatusException; import org.springframework.web.servlet.mvc.method.annotation.StreamingResponseBody; -import java.util.ArrayList; -import java.util.Collection; -import java.util.List; - @RestController public class ChatController { @@ -36,16 +36,17 @@ public ChatController(RAGApproachFactory ragAp this.ragApproachFactory = ragApproachFactory; } - @PostMapping( - value = "/api/chat", - produces = MediaType.APPLICATION_NDJSON_VALUE - ) + @PostMapping(value = "/api/chat", produces = MediaType.APPLICATION_NDJSON_VALUE) public ResponseEntity openAIAskStream( - @RequestBody ChatAppRequest chatRequest - ) { + @RequestBody ChatAppRequest chatRequest) { if (!chatRequest.stream()) { - LOGGER.warn("Requested a content-type of application/ndjson however did not requested streaming. Please use a content-type of application/json"); - throw new ResponseStatusException(HttpStatus.BAD_REQUEST, "Requested a content-type of application/ndjson however did not requested streaming. Please use a content-type of application/json"); + LOGGER.warn( + "Requested a content-type of application/ndjson however did not requested" + + " streaming. Please use a content-type of application/json"); + throw new ResponseStatusException( + HttpStatus.BAD_REQUEST, + "Requested a content-type of application/ndjson however did not requested" + + " streaming. Please use a content-type of application/json"); } LOGGER.info("Received request for chat api with approach[{}]", chatRequest.approach()); @@ -60,42 +61,46 @@ public ResponseEntity openAIAskStream( return ResponseEntity.status(HttpStatus.BAD_REQUEST).body(null); } - var ragOptions = new RAGOptions.Builder() - .retrievialMode(chatRequest.context().overrides().retrieval_mode().name()) - .semanticRanker(chatRequest.context().overrides().semantic_ranker()) - .semanticCaptions(chatRequest.context().overrides().semantic_captions()) - .suggestFollowupQuestions(chatRequest.context().overrides().suggest_followup_questions()) - .excludeCategory(chatRequest.context().overrides().exclude_category()) - .promptTemplate(chatRequest.context().overrides().prompt_template()) - .top(chatRequest.context().overrides().top()) - .build(); - - RAGApproach ragApproach = ragApproachFactory.createApproach(chatRequest.approach(), RAGType.CHAT, ragOptions); + var ragOptions = + new RAGOptions.Builder() + .retrievialMode(chatRequest.context().overrides().retrieval_mode().name()) + .semanticRanker(chatRequest.context().overrides().semantic_ranker()) + .semanticCaptions(chatRequest.context().overrides().semantic_captions()) + .suggestFollowupQuestions( + chatRequest.context().overrides().suggest_followup_questions()) + .excludeCategory(chatRequest.context().overrides().exclude_category()) + .promptTemplate(chatRequest.context().overrides().prompt_template()) + .top(chatRequest.context().overrides().top()) + .build(); + + RAGApproach ragApproach = + ragApproachFactory.createApproach(chatRequest.approach(), RAGType.CHAT, ragOptions); ChatGPTConversation chatGPTConversation = convertToChatGPT(chatRequest.messages()); - StreamingResponseBody response = output -> { - try { - ragApproach.runStreaming(chatGPTConversation, ragOptions, output); - } finally { - output.flush(); - output.close(); - } - }; - - return ResponseEntity.ok() - .contentType(MediaType.APPLICATION_NDJSON) - .body(response); + StreamingResponseBody response = + output -> { + try { + ragApproach.runStreaming(chatGPTConversation, ragOptions, output); + } finally { + output.flush(); + output.close(); + } + }; + + return ResponseEntity.ok().contentType(MediaType.APPLICATION_NDJSON).body(response); } - @PostMapping( - value = "/api/chat", - produces = MediaType.APPLICATION_JSON_VALUE - ) + @PostMapping(value = "/api/chat", produces = MediaType.APPLICATION_JSON_VALUE) public ResponseEntity openAIAsk(@RequestBody ChatAppRequest chatRequest) { if (chatRequest.stream()) { - LOGGER.warn("Requested a content-type of application/json however also requested streaming. Please use a content-type of application/ndjson"); - throw new ResponseStatusException(HttpStatus.BAD_REQUEST, "Requested a content-type of application/json however also requested streaming. Please use a content-type of application/ndjson"); + LOGGER.warn( + "Requested a content-type of application/json however also requested streaming." + + " Please use a content-type of application/ndjson"); + throw new ResponseStatusException( + HttpStatus.BAD_REQUEST, + "Requested a content-type of application/json however also requested streaming." + + " Please use a content-type of application/ndjson"); } LOGGER.info("Received request for chat api with approach[{}]", chatRequest.approach()); @@ -110,32 +115,39 @@ public ResponseEntity openAIAsk(@RequestBody ChatAppRequest chatRe return ResponseEntity.status(HttpStatus.BAD_REQUEST).body(null); } - var ragOptions = new RAGOptions.Builder() - .retrievialMode(chatRequest.context().overrides().retrieval_mode().name()) - .semanticRanker(chatRequest.context().overrides().semantic_ranker()) - .semanticCaptions(chatRequest.context().overrides().semantic_captions()) - .suggestFollowupQuestions(chatRequest.context().overrides().suggest_followup_questions()) - .excludeCategory(chatRequest.context().overrides().exclude_category()) - .promptTemplate(chatRequest.context().overrides().prompt_template()) - .top(chatRequest.context().overrides().top()) - .build(); - - RAGApproach ragApproach = ragApproachFactory.createApproach(chatRequest.approach(), RAGType.CHAT, ragOptions); - + var ragOptions = + new RAGOptions.Builder() + .retrievialMode(chatRequest.context().overrides().retrieval_mode().name()) + .semanticRanker(chatRequest.context().overrides().semantic_ranker()) + .semanticCaptions(chatRequest.context().overrides().semantic_captions()) + .suggestFollowupQuestions( + chatRequest.context().overrides().suggest_followup_questions()) + .excludeCategory(chatRequest.context().overrides().exclude_category()) + .promptTemplate(chatRequest.context().overrides().prompt_template()) + .top(chatRequest.context().overrides().top()) + .build(); + + RAGApproach ragApproach = + ragApproachFactory.createApproach(chatRequest.approach(), RAGType.CHAT, ragOptions); ChatGPTConversation chatGPTConversation = convertToChatGPT(chatRequest.messages()); - return ResponseEntity.ok(ChatResponse.buildChatResponse(ragApproach.run(chatGPTConversation, ragOptions))); - + return ResponseEntity.ok( + ChatResponse.buildChatResponse(ragApproach.run(chatGPTConversation, ragOptions))); } private ChatGPTConversation convertToChatGPT(List chatHistory) { return new ChatGPTConversation( chatHistory.stream() - .map(historyChat -> { - List chatGPTMessages = new ArrayList<>(); - chatGPTMessages.add(new ChatGPTMessage(ChatGPTMessage.ChatRole.fromString(historyChat.role()), historyChat.content())); - return chatGPTMessages; - }) + .map( + historyChat -> { + List chatGPTMessages = new ArrayList<>(); + chatGPTMessages.add( + new ChatGPTMessage( + ChatGPTMessage.ChatRole.fromString( + historyChat.role()), + historyChat.content())); + return chatGPTMessages; + }) .flatMap(Collection::stream) .toList()); } diff --git a/app/backend/src/main/java/com/microsoft/openai/samples/rag/common/ChatGPTConversation.java b/app/backend/src/main/java/com/microsoft/openai/samples/rag/common/ChatGPTConversation.java index ccb42d8..c57f057 100644 --- a/app/backend/src/main/java/com/microsoft/openai/samples/rag/common/ChatGPTConversation.java +++ b/app/backend/src/main/java/com/microsoft/openai/samples/rag/common/ChatGPTConversation.java @@ -1,35 +1,37 @@ -package com.microsoft.openai.samples.rag.common; - -import com.azure.ai.openai.models.ChatMessage; - -import java.util.List; - -public class ChatGPTConversation { - - private List messages; - private Integer tokenCount = 0; - - public ChatGPTConversation(List messages) { - this.messages = messages; - } - - public List toOpenAIChatMessages() { - return this.messages.stream() - .map(message -> - { - ChatMessage chatMessage = new ChatMessage(com.azure.ai.openai.models.ChatRole.fromString(message.role().toString())); - chatMessage.setContent(message.content()); - return chatMessage; - }) - .toList(); - } - - public List getMessages() { - return messages; - } - - public void setMessages(List messages) { - this.messages = messages; - } - -} +// Copyright (c) Microsoft. All rights reserved. +package com.microsoft.openai.samples.rag.common; + +import com.azure.ai.openai.models.ChatMessage; +import java.util.List; + +public class ChatGPTConversation { + + private List messages; + private Integer tokenCount = 0; + + public ChatGPTConversation(List messages) { + this.messages = messages; + } + + public List toOpenAIChatMessages() { + return this.messages.stream() + .map( + message -> { + ChatMessage chatMessage = + new ChatMessage( + com.azure.ai.openai.models.ChatRole.fromString( + message.role().toString())); + chatMessage.setContent(message.content()); + return chatMessage; + }) + .toList(); + } + + public List getMessages() { + return messages; + } + + public void setMessages(List messages) { + this.messages = messages; + } +} diff --git a/app/backend/src/main/java/com/microsoft/openai/samples/rag/common/ChatGPTMessage.java b/app/backend/src/main/java/com/microsoft/openai/samples/rag/common/ChatGPTMessage.java index 7733db2..dbd0045 100644 --- a/app/backend/src/main/java/com/microsoft/openai/samples/rag/common/ChatGPTMessage.java +++ b/app/backend/src/main/java/com/microsoft/openai/samples/rag/common/ChatGPTMessage.java @@ -1,26 +1,24 @@ -package com.microsoft.openai.samples.rag.common; - -import com.azure.core.util.ExpandableStringEnum; - -import java.util.Collection; - -public record ChatGPTMessage(ChatGPTMessage.ChatRole role, - String content) { - - public static final class ChatRole extends ExpandableStringEnum { - public static final ChatGPTMessage.ChatRole SYSTEM = fromString("system"); - - public static final ChatGPTMessage.ChatRole ASSISTANT = fromString("assistant"); - - public static final ChatGPTMessage.ChatRole USER = fromString("user"); - - public static ChatGPTMessage.ChatRole fromString(String name) { - return fromString(name, ChatGPTMessage.ChatRole.class); - } - - public static Collection values() { - return values(ChatGPTMessage.ChatRole.class); - } - } - -} +// Copyright (c) Microsoft. All rights reserved. +package com.microsoft.openai.samples.rag.common; + +import com.azure.core.util.ExpandableStringEnum; +import java.util.Collection; + +public record ChatGPTMessage(ChatGPTMessage.ChatRole role, String content) { + + public static final class ChatRole extends ExpandableStringEnum { + public static final ChatGPTMessage.ChatRole SYSTEM = fromString("system"); + + public static final ChatGPTMessage.ChatRole ASSISTANT = fromString("assistant"); + + public static final ChatGPTMessage.ChatRole USER = fromString("user"); + + public static ChatGPTMessage.ChatRole fromString(String name) { + return fromString(name, ChatGPTMessage.ChatRole.class); + } + + public static Collection values() { + return values(ChatGPTMessage.ChatRole.class); + } + } +} diff --git a/app/backend/src/main/java/com/microsoft/openai/samples/rag/common/ChatGPTUtils.java b/app/backend/src/main/java/com/microsoft/openai/samples/rag/common/ChatGPTUtils.java index db775bf..916c582 100644 --- a/app/backend/src/main/java/com/microsoft/openai/samples/rag/common/ChatGPTUtils.java +++ b/app/backend/src/main/java/com/microsoft/openai/samples/rag/common/ChatGPTUtils.java @@ -1,51 +1,51 @@ -package com.microsoft.openai.samples.rag.common; - -import com.azure.ai.openai.models.ChatCompletionsOptions; -import com.azure.ai.openai.models.ChatMessage; -import com.azure.ai.openai.models.ChatRole; - -import java.util.ArrayList; -import java.util.HashMap; -import java.util.List; - -public class ChatGPTUtils { - - public static ChatCompletionsOptions buildDefaultChatCompletionsOptions(List messages) { - ChatCompletionsOptions completionsOptions = new ChatCompletionsOptions(messages); - - completionsOptions.setMaxTokens(1024); - completionsOptions.setTemperature(0.1); - completionsOptions.setTopP(1.0); - //completionsOptions.setStop(new ArrayList<>(List.of("\n"))); - completionsOptions.setLogitBias(new HashMap<>()); - completionsOptions.setN(1); - completionsOptions.setStream(false); - completionsOptions.setUser("search-openai-demo-java"); - completionsOptions.setPresencePenalty(0.0); - completionsOptions.setFrequencyPenalty(0.0); - - return completionsOptions; - } - - public static String formatAsChatML(List messages) { - StringBuilder sb = new StringBuilder(); - messages.forEach(message -> { - if(message.getRole() == ChatRole.USER){ - sb.append("<|im_start|>user\n"); - } else if(message.getRole() == ChatRole.ASSISTANT) { - sb.append("<|im_start|>assistant\n"); - } else { - sb.append("<|im_start|>system\n"); - } - sb.append(message.getContent()).append("\n").append("|im_end|").append("\n"); - }); - return sb.toString(); - } - - public static String getLastUserQuestion(List messages){ - ChatGPTMessage message = messages.get(messages.size()-1); - if(message.role() != ChatGPTMessage.ChatRole.USER) - return message.content(); - return ""; - } -} +// Copyright (c) Microsoft. All rights reserved. +package com.microsoft.openai.samples.rag.common; + +import com.azure.ai.openai.models.ChatCompletionsOptions; +import com.azure.ai.openai.models.ChatMessage; +import com.azure.ai.openai.models.ChatRole; +import java.util.HashMap; +import java.util.List; + +public class ChatGPTUtils { + + public static ChatCompletionsOptions buildDefaultChatCompletionsOptions( + List messages) { + ChatCompletionsOptions completionsOptions = new ChatCompletionsOptions(messages); + + completionsOptions.setMaxTokens(1024); + completionsOptions.setTemperature(0.1); + completionsOptions.setTopP(1.0); + // completionsOptions.setStop(new ArrayList<>(List.of("\n"))); + completionsOptions.setLogitBias(new HashMap<>()); + completionsOptions.setN(1); + completionsOptions.setStream(false); + completionsOptions.setUser("search-openai-demo-java"); + completionsOptions.setPresencePenalty(0.0); + completionsOptions.setFrequencyPenalty(0.0); + + return completionsOptions; + } + + public static String formatAsChatML(List messages) { + StringBuilder sb = new StringBuilder(); + messages.forEach( + message -> { + if (message.getRole() == ChatRole.USER) { + sb.append("<|im_start|>user\n"); + } else if (message.getRole() == ChatRole.ASSISTANT) { + sb.append("<|im_start|>assistant\n"); + } else { + sb.append("<|im_start|>system\n"); + } + sb.append(message.getContent()).append("\n").append("|im_end|").append("\n"); + }); + return sb.toString(); + } + + public static String getLastUserQuestion(List messages) { + ChatGPTMessage message = messages.get(messages.size() - 1); + if (message.role() != ChatGPTMessage.ChatRole.USER) return message.content(); + return ""; + } +} diff --git a/app/backend/src/main/java/com/microsoft/openai/samples/rag/config/AzureAuthenticationConfiguration.java b/app/backend/src/main/java/com/microsoft/openai/samples/rag/config/AzureAuthenticationConfiguration.java index 5d814d4..bc728b4 100644 --- a/app/backend/src/main/java/com/microsoft/openai/samples/rag/config/AzureAuthenticationConfiguration.java +++ b/app/backend/src/main/java/com/microsoft/openai/samples/rag/config/AzureAuthenticationConfiguration.java @@ -1,25 +1,25 @@ -package com.microsoft.openai.samples.rag.config; - -import com.azure.core.credential.TokenCredential; -import com.azure.identity.AzureCliCredentialBuilder; -import com.azure.identity.ManagedIdentityCredentialBuilder; -import org.springframework.context.annotation.Bean; -import org.springframework.context.annotation.Configuration; -import org.springframework.context.annotation.Profile; - -@Configuration -public class AzureAuthenticationConfiguration { - - @Profile("dev") - @Bean - public TokenCredential localTokenCredential() { - return new AzureCliCredentialBuilder().build(); - } - - @Bean - @Profile("default") - public TokenCredential managedIdentityTokenCredential() { - return new ManagedIdentityCredentialBuilder().build(); - } - -} +// Copyright (c) Microsoft. All rights reserved. +package com.microsoft.openai.samples.rag.config; + +import com.azure.core.credential.TokenCredential; +import com.azure.identity.AzureCliCredentialBuilder; +import com.azure.identity.ManagedIdentityCredentialBuilder; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.context.annotation.Profile; + +@Configuration +public class AzureAuthenticationConfiguration { + + @Profile("dev") + @Bean + public TokenCredential localTokenCredential() { + return new AzureCliCredentialBuilder().build(); + } + + @Bean + @Profile("default") + public TokenCredential managedIdentityTokenCredential() { + return new ManagedIdentityCredentialBuilder().build(); + } +} diff --git a/app/backend/src/main/java/com/microsoft/openai/samples/rag/config/CognitiveSearchConfiguration.java b/app/backend/src/main/java/com/microsoft/openai/samples/rag/config/CognitiveSearchConfiguration.java index 403dfc0..44a3abb 100644 --- a/app/backend/src/main/java/com/microsoft/openai/samples/rag/config/CognitiveSearchConfiguration.java +++ b/app/backend/src/main/java/com/microsoft/openai/samples/rag/config/CognitiveSearchConfiguration.java @@ -1,83 +1,85 @@ -package com.microsoft.openai.samples.rag.config; - -import com.azure.core.credential.TokenCredential; -import com.azure.core.http.policy.HttpLogDetailLevel; -import com.azure.core.http.policy.HttpLogOptions; -import com.azure.search.documents.SearchAsyncClient; -import com.azure.search.documents.SearchClient; -import com.azure.search.documents.SearchClientBuilder; -import org.springframework.beans.factory.annotation.Value; -import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty; -import org.springframework.context.annotation.Bean; -import org.springframework.context.annotation.Configuration; - -@Configuration -public class CognitiveSearchConfiguration { - - @Value("${cognitive.search.service}") String searchServiceName ; - @Value("${cognitive.search.index}") String indexName; - final TokenCredential tokenCredential; - - public CognitiveSearchConfiguration(TokenCredential tokenCredential) { - this.tokenCredential = tokenCredential; - } - - @Bean - @ConditionalOnProperty(name = "cognitive.tracing.enabled", havingValue = "true") - public SearchClient searchTracingEnabledClient() { - String endpoint = "https://%s.search.windows.net".formatted(searchServiceName); - - var httpLogOptions = new HttpLogOptions(); - httpLogOptions.setPrettyPrintBody(true); - httpLogOptions.setLogLevel(HttpLogDetailLevel.BODY_AND_HEADERS); - - return new SearchClientBuilder() - .endpoint(endpoint) - .credential(tokenCredential) - .indexName(indexName) - .httpLogOptions(httpLogOptions) - .buildClient(); - - } - - @Bean - @ConditionalOnProperty(name = "cognitive.tracing.enabled", havingValue = "false") - public SearchClient searchDefaultClient() { - String endpoint = "https://%s.search.windows.net".formatted(searchServiceName); - return new SearchClientBuilder() - .endpoint(endpoint) - .credential(tokenCredential) - .indexName(indexName) - .buildClient(); - } - - @Bean - @ConditionalOnProperty(name = "cognitive.tracing.enabled", havingValue = "true") - public SearchAsyncClient asyncSearchTracingEnabledClient() { - String endpoint = "https://%s.search.windows.net".formatted(searchServiceName); - - var httpLogOptions = new HttpLogOptions(); - httpLogOptions.setPrettyPrintBody(true); - httpLogOptions.setLogLevel(HttpLogDetailLevel.BODY_AND_HEADERS); - - return new SearchClientBuilder() - .endpoint(endpoint) - .credential(tokenCredential) - .indexName(indexName) - .httpLogOptions(httpLogOptions) - .buildAsyncClient(); - - } - - @Bean - @ConditionalOnProperty(name = "cognitive.tracing.enabled", havingValue = "false") - public SearchAsyncClient asyncSearchDefaultClient() { - String endpoint = "https://%s.search.windows.net".formatted(searchServiceName); - return new SearchClientBuilder() - .endpoint(endpoint) - .credential(tokenCredential) - .indexName(indexName) - .buildAsyncClient(); - } - -} +// Copyright (c) Microsoft. All rights reserved. +package com.microsoft.openai.samples.rag.config; + +import com.azure.core.credential.TokenCredential; +import com.azure.core.http.policy.HttpLogDetailLevel; +import com.azure.core.http.policy.HttpLogOptions; +import com.azure.search.documents.SearchAsyncClient; +import com.azure.search.documents.SearchClient; +import com.azure.search.documents.SearchClientBuilder; +import org.springframework.beans.factory.annotation.Value; +import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; + +@Configuration +public class CognitiveSearchConfiguration { + + @Value("${cognitive.search.service}") + String searchServiceName; + + @Value("${cognitive.search.index}") + String indexName; + + final TokenCredential tokenCredential; + + public CognitiveSearchConfiguration(TokenCredential tokenCredential) { + this.tokenCredential = tokenCredential; + } + + @Bean + @ConditionalOnProperty(name = "cognitive.tracing.enabled", havingValue = "true") + public SearchClient searchTracingEnabledClient() { + String endpoint = "https://%s.search.windows.net".formatted(searchServiceName); + + var httpLogOptions = new HttpLogOptions(); + httpLogOptions.setPrettyPrintBody(true); + httpLogOptions.setLogLevel(HttpLogDetailLevel.BODY_AND_HEADERS); + + return new SearchClientBuilder() + .endpoint(endpoint) + .credential(tokenCredential) + .indexName(indexName) + .httpLogOptions(httpLogOptions) + .buildClient(); + } + + @Bean + @ConditionalOnProperty(name = "cognitive.tracing.enabled", havingValue = "false") + public SearchClient searchDefaultClient() { + String endpoint = "https://%s.search.windows.net".formatted(searchServiceName); + return new SearchClientBuilder() + .endpoint(endpoint) + .credential(tokenCredential) + .indexName(indexName) + .buildClient(); + } + + @Bean + @ConditionalOnProperty(name = "cognitive.tracing.enabled", havingValue = "true") + public SearchAsyncClient asyncSearchTracingEnabledClient() { + String endpoint = "https://%s.search.windows.net".formatted(searchServiceName); + + var httpLogOptions = new HttpLogOptions(); + httpLogOptions.setPrettyPrintBody(true); + httpLogOptions.setLogLevel(HttpLogDetailLevel.BODY_AND_HEADERS); + + return new SearchClientBuilder() + .endpoint(endpoint) + .credential(tokenCredential) + .indexName(indexName) + .httpLogOptions(httpLogOptions) + .buildAsyncClient(); + } + + @Bean + @ConditionalOnProperty(name = "cognitive.tracing.enabled", havingValue = "false") + public SearchAsyncClient asyncSearchDefaultClient() { + String endpoint = "https://%s.search.windows.net".formatted(searchServiceName); + return new SearchClientBuilder() + .endpoint(endpoint) + .credential(tokenCredential) + .indexName(indexName) + .buildAsyncClient(); + } +} diff --git a/app/backend/src/main/java/com/microsoft/openai/samples/rag/config/OpenAIConfiguration.java b/app/backend/src/main/java/com/microsoft/openai/samples/rag/config/OpenAIConfiguration.java index 9b3ed40..b8df022 100644 --- a/app/backend/src/main/java/com/microsoft/openai/samples/rag/config/OpenAIConfiguration.java +++ b/app/backend/src/main/java/com/microsoft/openai/samples/rag/config/OpenAIConfiguration.java @@ -1,79 +1,78 @@ -package com.microsoft.openai.samples.rag.config; - -import com.azure.ai.openai.OpenAIAsyncClient; -import com.azure.ai.openai.OpenAIClient; -import com.azure.ai.openai.OpenAIClientBuilder; -import com.azure.core.credential.TokenCredential; -import com.azure.core.http.policy.HttpLogDetailLevel; -import com.azure.core.http.policy.HttpLogOptions; -import org.springframework.beans.factory.annotation.Value; -import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty; -import org.springframework.context.annotation.Bean; -import org.springframework.context.annotation.Configuration; - -@Configuration -public class OpenAIConfiguration { - - @Value("${openai.service}") - String openAIServiceName; - final TokenCredential tokenCredential; - - public OpenAIConfiguration(TokenCredential tokenCredential) { - this.tokenCredential = tokenCredential; - } - - @Bean - @ConditionalOnProperty(name = "openai.tracing.enabled", havingValue = "true") - public OpenAIClient openAItracingEnabledClient() { - String endpoint = "https://%s.openai.azure.com".formatted(openAIServiceName); - - var httpLogOptions = new HttpLogOptions(); - //httpLogOptions.setPrettyPrintBody(true); - httpLogOptions.setLogLevel(HttpLogDetailLevel.BODY_AND_HEADERS); - - return new OpenAIClientBuilder() - .endpoint(endpoint) - .credential(tokenCredential) - .httpLogOptions(httpLogOptions) - .buildClient(); - - } - - @Bean - @ConditionalOnProperty(name = "openai.tracing.enabled", havingValue = "false") - public OpenAIClient openAIDefaultClient() { - String endpoint = "https://%s.openai.azure.com".formatted(openAIServiceName); - return new OpenAIClientBuilder() - .endpoint(endpoint) - .credential(tokenCredential) - .buildClient(); - } - - @Bean - @ConditionalOnProperty(name = "openai.tracing.enabled", havingValue = "true") - public OpenAIAsyncClient tracingEnabledAsyncClient() { - String endpoint = "https://%s.openai.azure.com".formatted(openAIServiceName); - - var httpLogOptions = new HttpLogOptions(); - httpLogOptions.setPrettyPrintBody(true); - httpLogOptions.setLogLevel(HttpLogDetailLevel.BODY_AND_HEADERS); - - return new OpenAIClientBuilder() - .endpoint(endpoint) - .credential(tokenCredential) - .httpLogOptions(httpLogOptions) - .buildAsyncClient(); - - } - - @Bean - @ConditionalOnProperty(name = "openai.tracing.enabled", havingValue = "false") - public OpenAIAsyncClient defaultAsyncClient() { - String endpoint = "https://%s.openai.azure.com".formatted(openAIServiceName); - return new OpenAIClientBuilder() - .endpoint(endpoint) - .credential(tokenCredential) - .buildAsyncClient(); - } - -} +// Copyright (c) Microsoft. All rights reserved. +package com.microsoft.openai.samples.rag.config; + +import com.azure.ai.openai.OpenAIAsyncClient; +import com.azure.ai.openai.OpenAIClient; +import com.azure.ai.openai.OpenAIClientBuilder; +import com.azure.core.credential.TokenCredential; +import com.azure.core.http.policy.HttpLogDetailLevel; +import com.azure.core.http.policy.HttpLogOptions; +import org.springframework.beans.factory.annotation.Value; +import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; + +@Configuration +public class OpenAIConfiguration { + + @Value("${openai.service}") + String openAIServiceName; + + final TokenCredential tokenCredential; + + public OpenAIConfiguration(TokenCredential tokenCredential) { + this.tokenCredential = tokenCredential; + } + + @Bean + @ConditionalOnProperty(name = "openai.tracing.enabled", havingValue = "true") + public OpenAIClient openAItracingEnabledClient() { + String endpoint = "https://%s.openai.azure.com".formatted(openAIServiceName); + + var httpLogOptions = new HttpLogOptions(); + // httpLogOptions.setPrettyPrintBody(true); + httpLogOptions.setLogLevel(HttpLogDetailLevel.BODY_AND_HEADERS); + + return new OpenAIClientBuilder() + .endpoint(endpoint) + .credential(tokenCredential) + .httpLogOptions(httpLogOptions) + .buildClient(); + } + + @Bean + @ConditionalOnProperty(name = "openai.tracing.enabled", havingValue = "false") + public OpenAIClient openAIDefaultClient() { + String endpoint = "https://%s.openai.azure.com".formatted(openAIServiceName); + return new OpenAIClientBuilder() + .endpoint(endpoint) + .credential(tokenCredential) + .buildClient(); + } + + @Bean + @ConditionalOnProperty(name = "openai.tracing.enabled", havingValue = "true") + public OpenAIAsyncClient tracingEnabledAsyncClient() { + String endpoint = "https://%s.openai.azure.com".formatted(openAIServiceName); + + var httpLogOptions = new HttpLogOptions(); + httpLogOptions.setPrettyPrintBody(true); + httpLogOptions.setLogLevel(HttpLogDetailLevel.BODY_AND_HEADERS); + + return new OpenAIClientBuilder() + .endpoint(endpoint) + .credential(tokenCredential) + .httpLogOptions(httpLogOptions) + .buildAsyncClient(); + } + + @Bean + @ConditionalOnProperty(name = "openai.tracing.enabled", havingValue = "false") + public OpenAIAsyncClient defaultAsyncClient() { + String endpoint = "https://%s.openai.azure.com".formatted(openAIServiceName); + return new OpenAIClientBuilder() + .endpoint(endpoint) + .credential(tokenCredential) + .buildAsyncClient(); + } +} diff --git a/app/backend/src/main/java/com/microsoft/openai/samples/rag/content/controller/ContentController.java b/app/backend/src/main/java/com/microsoft/openai/samples/rag/content/controller/ContentController.java index 22b1b7f..ea1a3ca 100644 --- a/app/backend/src/main/java/com/microsoft/openai/samples/rag/content/controller/ContentController.java +++ b/app/backend/src/main/java/com/microsoft/openai/samples/rag/content/controller/ContentController.java @@ -1,6 +1,11 @@ +// Copyright (c) Microsoft. All rights reserved. package com.microsoft.openai.samples.rag.content.controller; import com.microsoft.openai.samples.rag.proxy.BlobStorageProxy; +import java.io.ByteArrayInputStream; +import java.io.IOException; +import java.io.InputStream; +import java.net.URLConnection; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.springframework.core.io.InputStreamResource; @@ -13,11 +18,6 @@ import org.springframework.web.bind.annotation.PathVariable; import org.springframework.web.bind.annotation.RestController; -import java.io.ByteArrayInputStream; -import java.io.IOException; -import java.io.InputStream; -import java.net.URLConnection; - @RestController public class ContentController { @@ -51,9 +51,8 @@ public ResponseEntity getContent(@PathVariable String fileN } return ResponseEntity.ok() - .header("Content-Disposition", "inline; filename=%s".formatted(fileName)) - .contentType(contentType) - .body(new InputStreamResource(fileInputStream)); + .header("Content-Disposition", "inline; filename=%s".formatted(fileName)) + .contentType(contentType) + .body(new InputStreamResource(fileInputStream)); } - } diff --git a/app/backend/src/main/java/com/microsoft/openai/samples/rag/controller/ChatAppRequest.java b/app/backend/src/main/java/com/microsoft/openai/samples/rag/controller/ChatAppRequest.java index d444865..b716fa1 100644 --- a/app/backend/src/main/java/com/microsoft/openai/samples/rag/controller/ChatAppRequest.java +++ b/app/backend/src/main/java/com/microsoft/openai/samples/rag/controller/ChatAppRequest.java @@ -1,3 +1,4 @@ +// Copyright (c) Microsoft. All rights reserved. package com.microsoft.openai.samples.rag.controller; import java.util.List; @@ -6,6 +7,4 @@ public record ChatAppRequest( List messages, ChatAppRequestContext context, boolean stream, - String approach -) { -} + String approach) {} diff --git a/app/backend/src/main/java/com/microsoft/openai/samples/rag/controller/ChatAppRequestContext.java b/app/backend/src/main/java/com/microsoft/openai/samples/rag/controller/ChatAppRequestContext.java index c5a552c..5c778ef 100644 --- a/app/backend/src/main/java/com/microsoft/openai/samples/rag/controller/ChatAppRequestContext.java +++ b/app/backend/src/main/java/com/microsoft/openai/samples/rag/controller/ChatAppRequestContext.java @@ -1,4 +1,4 @@ +// Copyright (c) Microsoft. All rights reserved. package com.microsoft.openai.samples.rag.controller; -public record ChatAppRequestContext(ChatAppRequestOverrides overrides) { -} \ No newline at end of file +public record ChatAppRequestContext(ChatAppRequestOverrides overrides) {} diff --git a/app/backend/src/main/java/com/microsoft/openai/samples/rag/controller/ChatAppRequestOverrides.java b/app/backend/src/main/java/com/microsoft/openai/samples/rag/controller/ChatAppRequestOverrides.java index b05b0e3..39341f3 100644 --- a/app/backend/src/main/java/com/microsoft/openai/samples/rag/controller/ChatAppRequestOverrides.java +++ b/app/backend/src/main/java/com/microsoft/openai/samples/rag/controller/ChatAppRequestOverrides.java @@ -1,3 +1,4 @@ +// Copyright (c) Microsoft. All rights reserved. package com.microsoft.openai.samples.rag.controller; import com.microsoft.openai.samples.rag.approaches.RetrievalMode; @@ -15,6 +16,4 @@ public record ChatAppRequestOverrides( boolean suggest_followup_questions, boolean use_oid_security_filter, boolean use_groups_security_filter, - String semantic_kernel_mode -) { -} \ No newline at end of file + String semantic_kernel_mode) {} diff --git a/app/backend/src/main/java/com/microsoft/openai/samples/rag/controller/ChatResponse.java b/app/backend/src/main/java/com/microsoft/openai/samples/rag/controller/ChatResponse.java index 8ee0689..b30604f 100644 --- a/app/backend/src/main/java/com/microsoft/openai/samples/rag/controller/ChatResponse.java +++ b/app/backend/src/main/java/com/microsoft/openai/samples/rag/controller/ChatResponse.java @@ -1,62 +1,53 @@ -package com.microsoft.openai.samples.rag.controller; - -import com.microsoft.openai.samples.rag.approaches.RAGResponse; -import com.microsoft.openai.samples.rag.common.ChatGPTMessage; - -import java.util.Collections; -import java.util.List; - -public record ChatResponse(List choices) { - - public static ChatResponse buildChatResponse(RAGResponse ragResponse) { - List dataPoints = Collections.emptyList(); - - if (ragResponse.getSources() != null) { - dataPoints = ragResponse.getSources().stream() - .map(source -> source.getSourceName() + ": " + source.getSourceContent()) - .toList(); - } - - String thoughts = "Question:
" + ragResponse.getQuestion() + "

Prompt:
" + ragResponse.getPrompt().replace("\n", "
"); - - return new ChatResponse( - List.of( - new ResponseChoice( - 0, - new ResponseMessage( - ragResponse.getAnswer(), - ChatGPTMessage.ChatRole.ASSISTANT.toString() - ), - new ResponseContext( - thoughts, - dataPoints - ), - new ResponseMessage( - ragResponse.getAnswer(), - ChatGPTMessage.ChatRole.ASSISTANT.toString() - ) - ) - ) - ); - } - - public static ChatResponse buildChatDeltaResponse(Integer index, RAGResponse ragResponse) { - return new ChatResponse( - List.of( - new ResponseChoice( - index, - new ResponseMessage( - ragResponse.getAnswer(), - "ASSISTANT" - ), - null, - new ResponseMessage( - ragResponse.getAnswer(), - "ASSISTANT" - ) - ) - ) - ); - } - -} \ No newline at end of file +// Copyright (c) Microsoft. All rights reserved. +package com.microsoft.openai.samples.rag.controller; + +import com.microsoft.openai.samples.rag.approaches.RAGResponse; +import com.microsoft.openai.samples.rag.common.ChatGPTMessage; +import java.util.Collections; +import java.util.List; + +public record ChatResponse(List choices) { + + public static ChatResponse buildChatResponse(RAGResponse ragResponse) { + List dataPoints = Collections.emptyList(); + + if (ragResponse.getSources() != null) { + dataPoints = + ragResponse.getSources().stream() + .map( + source -> + source.getSourceName() + + ": " + + source.getSourceContent()) + .toList(); + } + + String thoughts = + "Question:
" + + ragResponse.getQuestion() + + "

Prompt:
" + + ragResponse.getPrompt().replace("\n", "
"); + + return new ChatResponse( + List.of( + new ResponseChoice( + 0, + new ResponseMessage( + ragResponse.getAnswer(), + ChatGPTMessage.ChatRole.ASSISTANT.toString()), + new ResponseContext(thoughts, dataPoints), + new ResponseMessage( + ragResponse.getAnswer(), + ChatGPTMessage.ChatRole.ASSISTANT.toString())))); + } + + public static ChatResponse buildChatDeltaResponse(Integer index, RAGResponse ragResponse) { + return new ChatResponse( + List.of( + new ResponseChoice( + index, + new ResponseMessage(ragResponse.getAnswer(), "ASSISTANT"), + null, + new ResponseMessage(ragResponse.getAnswer(), "ASSISTANT")))); + } +} diff --git a/app/backend/src/main/java/com/microsoft/openai/samples/rag/controller/ResponseChoice.java b/app/backend/src/main/java/com/microsoft/openai/samples/rag/controller/ResponseChoice.java index f622d3d..705f264 100644 --- a/app/backend/src/main/java/com/microsoft/openai/samples/rag/controller/ResponseChoice.java +++ b/app/backend/src/main/java/com/microsoft/openai/samples/rag/controller/ResponseChoice.java @@ -1,4 +1,5 @@ +// Copyright (c) Microsoft. All rights reserved. package com.microsoft.openai.samples.rag.controller; -public record ResponseChoice(int index, ResponseMessage message, ResponseContext context, ResponseMessage delta) { -} +public record ResponseChoice( + int index, ResponseMessage message, ResponseContext context, ResponseMessage delta) {} diff --git a/app/backend/src/main/java/com/microsoft/openai/samples/rag/controller/ResponseContext.java b/app/backend/src/main/java/com/microsoft/openai/samples/rag/controller/ResponseContext.java index 1c55106..0166789 100644 --- a/app/backend/src/main/java/com/microsoft/openai/samples/rag/controller/ResponseContext.java +++ b/app/backend/src/main/java/com/microsoft/openai/samples/rag/controller/ResponseContext.java @@ -1,6 +1,6 @@ +// Copyright (c) Microsoft. All rights reserved. package com.microsoft.openai.samples.rag.controller; import java.util.List; -public record ResponseContext(String thoughts, List data_points) { -} \ No newline at end of file +public record ResponseContext(String thoughts, List data_points) {} diff --git a/app/backend/src/main/java/com/microsoft/openai/samples/rag/controller/ResponseMessage.java b/app/backend/src/main/java/com/microsoft/openai/samples/rag/controller/ResponseMessage.java index aeb918b..7dd5f00 100644 --- a/app/backend/src/main/java/com/microsoft/openai/samples/rag/controller/ResponseMessage.java +++ b/app/backend/src/main/java/com/microsoft/openai/samples/rag/controller/ResponseMessage.java @@ -1,4 +1,4 @@ +// Copyright (c) Microsoft. All rights reserved. package com.microsoft.openai.samples.rag.controller; -public record ResponseMessage(String content, String role) { -} \ No newline at end of file +public record ResponseMessage(String content, String role) {} diff --git a/app/backend/src/main/java/com/microsoft/openai/samples/rag/controller/auth/AuthSetup.java b/app/backend/src/main/java/com/microsoft/openai/samples/rag/controller/auth/AuthSetup.java index fada2dd..53a87dd 100644 --- a/app/backend/src/main/java/com/microsoft/openai/samples/rag/controller/auth/AuthSetup.java +++ b/app/backend/src/main/java/com/microsoft/openai/samples/rag/controller/auth/AuthSetup.java @@ -1,3 +1,4 @@ +// Copyright (c) Microsoft. All rights reserved. package com.microsoft.openai.samples.rag.controller.auth; import org.springframework.web.bind.annotation.GetMapping; @@ -12,6 +13,7 @@ public String authSetup() { { "useLogin": false } - """.stripIndent(); + """ + .stripIndent(); } } diff --git a/app/backend/src/main/java/com/microsoft/openai/samples/rag/proxy/BlobStorageProxy.java b/app/backend/src/main/java/com/microsoft/openai/samples/rag/proxy/BlobStorageProxy.java index beacfe6..591d273 100644 --- a/app/backend/src/main/java/com/microsoft/openai/samples/rag/proxy/BlobStorageProxy.java +++ b/app/backend/src/main/java/com/microsoft/openai/samples/rag/proxy/BlobStorageProxy.java @@ -1,48 +1,47 @@ -package com.microsoft.openai.samples.rag.proxy; - -import com.azure.core.credential.TokenCredential; -import com.azure.storage.blob.BlobContainerClient; -import com.azure.storage.blob.BlobContainerClientBuilder; -import org.springframework.beans.factory.annotation.Value; -import org.springframework.stereotype.Component; - -import java.io.ByteArrayOutputStream; -import java.io.IOException; - -/** - * This class is a proxy to the Blob storage API. - * It is responsible for: - * - calling the API - * - handling errors and retry strategy - * - add monitoring points - * - add circuit breaker with exponential backoff - */ -@Component -public class BlobStorageProxy { - - private final BlobContainerClient client; - - public BlobStorageProxy(@Value("${storage-account.service}") String storageAccountServiceName, - @Value("${blob.container.name}") String containerName, - TokenCredential tokenCredential) { - - String endpoint = "https://%s.blob.core.windows.net".formatted(storageAccountServiceName); - this.client = new BlobContainerClientBuilder() - .endpoint(endpoint) - .credential(tokenCredential) - .containerName(containerName) - .buildClient(); - } - - public byte[] getFileAsBytes(String fileName) throws IOException { - var blobClient = client.getBlobClient(fileName); - int dataSize = (int) blobClient.getProperties().getBlobSize(); - - // There is no need to close ByteArrayOutputStream. https://docs.oracle.com/javase/8/docs/api/java/io/ByteArrayOutputStream.html - ByteArrayOutputStream outputStream = new ByteArrayOutputStream(dataSize); - blobClient.downloadStream(outputStream); - - return outputStream.toByteArray(); - } - -} +// Copyright (c) Microsoft. All rights reserved. +package com.microsoft.openai.samples.rag.proxy; + +import com.azure.core.credential.TokenCredential; +import com.azure.storage.blob.BlobContainerClient; +import com.azure.storage.blob.BlobContainerClientBuilder; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import org.springframework.beans.factory.annotation.Value; +import org.springframework.stereotype.Component; + +/** + * This class is a proxy to the Blob storage API. It is responsible for: - calling the API - + * handling errors and retry strategy - add monitoring points - add circuit breaker with exponential + * backoff + */ +@Component +public class BlobStorageProxy { + + private final BlobContainerClient client; + + public BlobStorageProxy( + @Value("${storage-account.service}") String storageAccountServiceName, + @Value("${blob.container.name}") String containerName, + TokenCredential tokenCredential) { + + String endpoint = "https://%s.blob.core.windows.net".formatted(storageAccountServiceName); + this.client = + new BlobContainerClientBuilder() + .endpoint(endpoint) + .credential(tokenCredential) + .containerName(containerName) + .buildClient(); + } + + public byte[] getFileAsBytes(String fileName) throws IOException { + var blobClient = client.getBlobClient(fileName); + int dataSize = (int) blobClient.getProperties().getBlobSize(); + + // There is no need to close ByteArrayOutputStream. + // https://docs.oracle.com/javase/8/docs/api/java/io/ByteArrayOutputStream.html + ByteArrayOutputStream outputStream = new ByteArrayOutputStream(dataSize); + blobClient.downloadStream(outputStream); + + return outputStream.toByteArray(); + } +} diff --git a/app/backend/src/main/java/com/microsoft/openai/samples/rag/proxy/CognitiveSearchProxy.java b/app/backend/src/main/java/com/microsoft/openai/samples/rag/proxy/CognitiveSearchProxy.java index b28ee95..cc64336 100644 --- a/app/backend/src/main/java/com/microsoft/openai/samples/rag/proxy/CognitiveSearchProxy.java +++ b/app/backend/src/main/java/com/microsoft/openai/samples/rag/proxy/CognitiveSearchProxy.java @@ -1,30 +1,28 @@ -package com.microsoft.openai.samples.rag.proxy; - -import com.azure.core.util.Context; -import com.azure.search.documents.SearchClient; -import com.azure.search.documents.models.SearchOptions; -import com.azure.search.documents.util.SearchPagedIterable; -import org.springframework.stereotype.Component; - -/** - * This class is a proxy to the Cognitive Search API. - * It is responsible for: - * - calling the OpenAI API - * - handling errors and retry strategy - * - add monitoring points - * - add circuit breaker with exponential backoff - */ -@Component -public class CognitiveSearchProxy { - - private final SearchClient client; - - public CognitiveSearchProxy(SearchClient searchClient) { - this.client= searchClient; - } - - public SearchPagedIterable search(String searchText, SearchOptions searchOptions, Context context){ - return client.search(searchText,searchOptions,context); - } - -} +// Copyright (c) Microsoft. All rights reserved. +package com.microsoft.openai.samples.rag.proxy; + +import com.azure.core.util.Context; +import com.azure.search.documents.SearchClient; +import com.azure.search.documents.models.SearchOptions; +import com.azure.search.documents.util.SearchPagedIterable; +import org.springframework.stereotype.Component; + +/** + * This class is a proxy to the Cognitive Search API. It is responsible for: - calling the OpenAI + * API - handling errors and retry strategy - add monitoring points - add circuit breaker with + * exponential backoff + */ +@Component +public class CognitiveSearchProxy { + + private final SearchClient client; + + public CognitiveSearchProxy(SearchClient searchClient) { + this.client = searchClient; + } + + public SearchPagedIterable search( + String searchText, SearchOptions searchOptions, Context context) { + return client.search(searchText, searchOptions, context); + } +} diff --git a/app/backend/src/main/java/com/microsoft/openai/samples/rag/proxy/OpenAIProxy.java b/app/backend/src/main/java/com/microsoft/openai/samples/rag/proxy/OpenAIProxy.java index b6e53be..a8a1360 100644 --- a/app/backend/src/main/java/com/microsoft/openai/samples/rag/proxy/OpenAIProxy.java +++ b/app/backend/src/main/java/com/microsoft/openai/samples/rag/proxy/OpenAIProxy.java @@ -1,90 +1,103 @@ -package com.microsoft.openai.samples.rag.proxy; - -import com.azure.ai.openai.OpenAIClient; -import com.azure.ai.openai.models.*; -import com.azure.core.exception.HttpResponseException; -import com.azure.core.util.IterableStream; -import org.springframework.beans.factory.annotation.Value; -import org.springframework.stereotype.Component; -import org.springframework.web.server.ResponseStatusException; - -import java.util.List; - -/** - * This class is a proxy to the OpenAI API to simplify cross-cutting concerns management (security, load balancing, monitoring, resiliency). - * It is responsible for: - * - calling the OpenAI API - * - handling errors and retry strategy - * - load balance requests across open AI instances - * - add monitoring points - * - add circuit breaker with exponential backoff - *

- * It also makes unit testing easy using mockito to provide mock implementation for this bean. - */ -@Component -public class OpenAIProxy { - - private final OpenAIClient client; - @Value("${openai.chatgpt.deployment}") - private String gptChatDeploymentModelId; - - @Value("${openai.embedding.deployment}") - private String embeddingDeploymentModelId; - - - public OpenAIProxy(OpenAIClient client) { - this.client = client; - } - - public Completions getCompletions(CompletionsOptions completionsOptions) { - Completions completions; - try { - completions = client.getCompletions(this.gptChatDeploymentModelId, completionsOptions); - } catch (HttpResponseException e) { - throw new ResponseStatusException(e.getResponse().getStatusCode(), "Error calling OpenAI API:" + e.getValue(), e); - } - return completions; - } - - public Completions getCompletions(String prompt) { - - Completions completions; - try { - completions = client.getCompletions(this.gptChatDeploymentModelId, prompt); - } catch (HttpResponseException e) { - throw new ResponseStatusException(e.getResponse().getStatusCode(), "Error calling OpenAI API:" + e.getMessage(), e); - } - return completions; - } - - public ChatCompletions getChatCompletions(ChatCompletionsOptions chatCompletionsOptions) { - ChatCompletions chatCompletions; - try { - chatCompletions = client.getChatCompletions(this.gptChatDeploymentModelId, chatCompletionsOptions); - } catch (HttpResponseException e) { - throw new ResponseStatusException(e.getResponse().getStatusCode(), "Error calling OpenAI API:" + e.getMessage(), e); - } - return chatCompletions; - } - - public IterableStream getChatCompletionsStream(ChatCompletionsOptions chatCompletionsOptions) { - try { - return client.getChatCompletionsStream(this.gptChatDeploymentModelId, chatCompletionsOptions); - } catch (HttpResponseException e) { - throw new ResponseStatusException(e.getResponse().getStatusCode(), "Error calling OpenAI API:" + e.getMessage(), e); - } - } - - public Embeddings getEmbeddings(List texts){ - Embeddings embeddings; - try { - EmbeddingsOptions embeddingsOptions = new EmbeddingsOptions(texts); - embeddingsOptions.setUser("search-openai-demo-java"); - embeddings = client.getEmbeddings(this.embeddingDeploymentModelId, embeddingsOptions); - } catch (HttpResponseException e) { - throw new ResponseStatusException(e.getResponse().getStatusCode(), "Error calling OpenAI API:" + e.getMessage(), e); - } - return embeddings; - } - -} +// Copyright (c) Microsoft. All rights reserved. +package com.microsoft.openai.samples.rag.proxy; + +import com.azure.ai.openai.OpenAIClient; +import com.azure.ai.openai.models.*; +import com.azure.core.exception.HttpResponseException; +import com.azure.core.util.IterableStream; +import java.util.List; +import org.springframework.beans.factory.annotation.Value; +import org.springframework.stereotype.Component; +import org.springframework.web.server.ResponseStatusException; + +/** + * This class is a proxy to the OpenAI API to simplify cross-cutting concerns management (security, + * load balancing, monitoring, resiliency). It is responsible for: - calling the OpenAI API - + * handling errors and retry strategy - load balance requests across open AI instances - add + * monitoring points - add circuit breaker with exponential backoff + * + *

It also makes unit testing easy using mockito to provide mock implementation for this bean. + */ +@Component +public class OpenAIProxy { + + private final OpenAIClient client; + + @Value("${openai.chatgpt.deployment}") + private String gptChatDeploymentModelId; + + @Value("${openai.embedding.deployment}") + private String embeddingDeploymentModelId; + + public OpenAIProxy(OpenAIClient client) { + this.client = client; + } + + public Completions getCompletions(CompletionsOptions completionsOptions) { + Completions completions; + try { + completions = client.getCompletions(this.gptChatDeploymentModelId, completionsOptions); + } catch (HttpResponseException e) { + throw new ResponseStatusException( + e.getResponse().getStatusCode(), "Error calling OpenAI API:" + e.getValue(), e); + } + return completions; + } + + public Completions getCompletions(String prompt) { + + Completions completions; + try { + completions = client.getCompletions(this.gptChatDeploymentModelId, prompt); + } catch (HttpResponseException e) { + throw new ResponseStatusException( + e.getResponse().getStatusCode(), + "Error calling OpenAI API:" + e.getMessage(), + e); + } + return completions; + } + + public ChatCompletions getChatCompletions(ChatCompletionsOptions chatCompletionsOptions) { + ChatCompletions chatCompletions; + try { + chatCompletions = + client.getChatCompletions( + this.gptChatDeploymentModelId, chatCompletionsOptions); + } catch (HttpResponseException e) { + throw new ResponseStatusException( + e.getResponse().getStatusCode(), + "Error calling OpenAI API:" + e.getMessage(), + e); + } + return chatCompletions; + } + + public IterableStream getChatCompletionsStream( + ChatCompletionsOptions chatCompletionsOptions) { + try { + return client.getChatCompletionsStream( + this.gptChatDeploymentModelId, chatCompletionsOptions); + } catch (HttpResponseException e) { + throw new ResponseStatusException( + e.getResponse().getStatusCode(), + "Error calling OpenAI API:" + e.getMessage(), + e); + } + } + + public Embeddings getEmbeddings(List texts) { + Embeddings embeddings; + try { + EmbeddingsOptions embeddingsOptions = new EmbeddingsOptions(texts); + embeddingsOptions.setUser("search-openai-demo-java"); + embeddings = client.getEmbeddings(this.embeddingDeploymentModelId, embeddingsOptions); + } catch (HttpResponseException e) { + throw new ResponseStatusException( + e.getResponse().getStatusCode(), + "Error calling OpenAI API:" + e.getMessage(), + e); + } + return embeddings; + } +} diff --git a/app/backend/src/main/java/com/microsoft/openai/samples/rag/retrieval/CognitiveSearchRetriever.java b/app/backend/src/main/java/com/microsoft/openai/samples/rag/retrieval/CognitiveSearchRetriever.java index 4e41390..e4a8dab 100644 --- a/app/backend/src/main/java/com/microsoft/openai/samples/rag/retrieval/CognitiveSearchRetriever.java +++ b/app/backend/src/main/java/com/microsoft/openai/samples/rag/retrieval/CognitiveSearchRetriever.java @@ -1,161 +1,191 @@ -package com.microsoft.openai.samples.rag.retrieval; - -import com.azure.ai.openai.models.ChatCompletions; -import com.azure.ai.openai.models.Embeddings; -import com.azure.core.util.Context; -import com.azure.search.documents.SearchDocument; -import com.azure.search.documents.models.*; -import com.azure.search.documents.util.SearchPagedIterable; -import com.microsoft.openai.samples.rag.approaches.ContentSource; -import com.microsoft.openai.samples.rag.approaches.RAGOptions; -import com.microsoft.openai.samples.rag.approaches.RetrievalMode; -import com.microsoft.openai.samples.rag.common.ChatGPTConversation; -import com.microsoft.openai.samples.rag.common.ChatGPTUtils; -import com.microsoft.openai.samples.rag.proxy.CognitiveSearchProxy; -import com.microsoft.openai.samples.rag.proxy.OpenAIProxy; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; -import org.springframework.stereotype.Component; - -import java.util.ArrayList; -import java.util.List; -import java.util.Optional; - -/** - * Cognitive Search retriever implementation that uses the Cognitive Search API to retrieve documents from the search - * index. - * If retrieval mode is set to vectors or hybrid, it will use OpenAI embedding API to convert the user's query text to an embedding vector - * The hybrid search is specific to cognitive search feature which fuses the best of text search and vector search. - */ -@Component -public class CognitiveSearchRetriever implements Retriever{ - private static final Logger LOGGER = LoggerFactory.getLogger(CognitiveSearchRetriever.class); - private final CognitiveSearchProxy cognitiveSearchProxy; - private final OpenAIProxy openAIProxy; - - public CognitiveSearchRetriever(CognitiveSearchProxy cognitiveSearchProxy, OpenAIProxy openAIProxy){ - this.cognitiveSearchProxy = cognitiveSearchProxy; - this.openAIProxy = openAIProxy; - } - - /** - * - * @param question - * @param ragOptions - * @return the top documents retrieved from the search index based on the user's query text - */ - @Override - public List retrieveFromQuestion(String question, RAGOptions ragOptions) { - // step 1. Convert the user's query text to an embedding - SearchOptions searchOptions = new SearchOptions(); - String searchText = null; - - if(ragOptions.getRetrievalMode() == RetrievalMode.vectors || ragOptions.getRetrievalMode() == RetrievalMode.hybrid) { - LOGGER.info("Retrieval mode is set to {}. Retrieving vectors for question [{}]",ragOptions.getRetrievalMode(),question); - - Embeddings response = openAIProxy.getEmbeddings(List.of(question)); - var questionVector = response.getData().get(0).getEmbedding().stream().map(Double::floatValue).toList(); - if (ragOptions.getRetrievalMode() == RetrievalMode.vectors) { - setSearchOptionsForVector(ragOptions, questionVector,searchOptions); - } else { - searchText = question; - setSearchOptionsForHybrid(ragOptions, questionVector,searchOptions); - } - } else { - searchText = question; - setSearchOptions(ragOptions,searchOptions); - } - - SearchPagedIterable searchResults = cognitiveSearchProxy.search(searchText,searchOptions, Context.NONE); - return buildSourcesFromSearchResults(ragOptions, searchResults); - - } - - /** - * - * @param conversation - * @param ragOptions - * @return facts retrieved from the search index based on GPT optimized search keywords extracted from the chat history - */ - @Override - public List retrieveFromConversation(ChatGPTConversation conversation, RAGOptions ragOptions) { - - // STEP 1: Generate an optimized keyword search query based on the chat history and the last question - var extractKeywordsChatTemplate = new ExtractKeywordsChatTemplate(conversation); - var chatCompletionsOptions = ChatGPTUtils.buildDefaultChatCompletionsOptions(extractKeywordsChatTemplate.getMessages()); - ChatCompletions chatCompletions = openAIProxy.getChatCompletions(chatCompletionsOptions); - - var searchKeywords = chatCompletions.getChoices().get(0).getMessage().getContent(); - LOGGER.info("Search Keywords extracted by Open AI [{}]",searchKeywords); - - // STEP 2: Retrieve relevant documents from the search index with the GPT optimized search keywords - return retrieveFromQuestion(searchKeywords, ragOptions); - } - - private List buildSourcesFromSearchResults(RAGOptions options, SearchPagedIterable searchResults) { - List sources = new ArrayList<>(); - - searchResults.iterator().forEachRemaining(result -> - { - var searchDocument = result.getDocument(SearchDocument.class); - - /* - If captions is enabled the content source is taken from the captions generated by the semantic ranker. - Captions are appended sequentially and separated by a dot. - */ - if(options.isSemanticCaptions()) { - StringBuilder sourcesContentBuffer = new StringBuilder(); - - result.getCaptions().forEach(caption -> sourcesContentBuffer.append(caption.getText()).append(".")); - - sources.add(new ContentSource((String)searchDocument.get("sourcepage"), sourcesContentBuffer.toString())); - } else { - //If captions is disabled the content source is taken from the cognitive search index field "content" - sources.add(new ContentSource((String) searchDocument.get("sourcepage"), (String) searchDocument.get("content"))); - } - }); - - return sources; - } - - - private void setSearchOptionsForHybrid(RAGOptions ragOptions, List questionVector,SearchOptions searchOptions) { - setSearchOptions(ragOptions,searchOptions); - setSearchOptionsForVector(ragOptions,questionVector,searchOptions); - } - private void setSearchOptionsForVector(RAGOptions options,List questionVector,SearchOptions searchOptions) { - - Optional.ofNullable(options.getTop()).ifPresentOrElse( - searchOptions::setTop, - () -> searchOptions.setTop(3)); - - searchOptions.setVectors( new SearchQueryVector() - .setValue(questionVector) - .setKNearestNeighborsCount(options.getTop()) - .setFields("embedding")); - - } - - private void setSearchOptions(RAGOptions options,SearchOptions searchOptions) { - - Optional.ofNullable(options.getTop()).ifPresentOrElse( - searchOptions::setTop, - () -> searchOptions.setTop(3)); - Optional.ofNullable(options.getExcludeCategory()) - .ifPresentOrElse( - value -> searchOptions.setFilter("category ne '%s'".formatted(value.replace("'", "''"))), - () -> searchOptions.setFilter(null)); - - Optional.ofNullable(options.isSemanticRanker()).ifPresent(isSemanticRanker -> { - if (isSemanticRanker) { - searchOptions.setQueryType(QueryType.SEMANTIC); - searchOptions.setQueryLanguage(QueryLanguage.EN_US); - searchOptions.setSpeller(QuerySpellerType.LEXICON); - searchOptions.setSemanticConfigurationName("default"); - searchOptions.setQueryCaption(QueryCaptionType.EXTRACTIVE); - searchOptions.setQueryCaptionHighlightEnabled(false); - } - }); - - } -} +// Copyright (c) Microsoft. All rights reserved. +package com.microsoft.openai.samples.rag.retrieval; + +import com.azure.ai.openai.models.ChatCompletions; +import com.azure.ai.openai.models.Embeddings; +import com.azure.core.util.Context; +import com.azure.search.documents.SearchDocument; +import com.azure.search.documents.models.*; +import com.azure.search.documents.util.SearchPagedIterable; +import com.microsoft.openai.samples.rag.approaches.ContentSource; +import com.microsoft.openai.samples.rag.approaches.RAGOptions; +import com.microsoft.openai.samples.rag.approaches.RetrievalMode; +import com.microsoft.openai.samples.rag.common.ChatGPTConversation; +import com.microsoft.openai.samples.rag.common.ChatGPTUtils; +import com.microsoft.openai.samples.rag.proxy.CognitiveSearchProxy; +import com.microsoft.openai.samples.rag.proxy.OpenAIProxy; +import java.util.ArrayList; +import java.util.List; +import java.util.Optional; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.stereotype.Component; + +/** + * Cognitive Search retriever implementation that uses the Cognitive Search API to retrieve + * documents from the search index. If retrieval mode is set to vectors or hybrid, it will use + * OpenAI embedding API to convert the user's query text to an embedding vector The hybrid search is + * specific to cognitive search feature which fuses the best of text search and vector search. + */ +@Component +public class CognitiveSearchRetriever implements Retriever { + private static final Logger LOGGER = LoggerFactory.getLogger(CognitiveSearchRetriever.class); + private final CognitiveSearchProxy cognitiveSearchProxy; + private final OpenAIProxy openAIProxy; + + public CognitiveSearchRetriever( + CognitiveSearchProxy cognitiveSearchProxy, OpenAIProxy openAIProxy) { + this.cognitiveSearchProxy = cognitiveSearchProxy; + this.openAIProxy = openAIProxy; + } + + /** + * @param question + * @param ragOptions + * @return the top documents retrieved from the search index based on the user's query text + */ + @Override + public List retrieveFromQuestion(String question, RAGOptions ragOptions) { + // step 1. Convert the user's query text to an embedding + SearchOptions searchOptions = new SearchOptions(); + String searchText = null; + + if (ragOptions.getRetrievalMode() == RetrievalMode.vectors + || ragOptions.getRetrievalMode() == RetrievalMode.hybrid) { + LOGGER.info( + "Retrieval mode is set to {}. Retrieving vectors for question [{}]", + ragOptions.getRetrievalMode(), + question); + + Embeddings response = openAIProxy.getEmbeddings(List.of(question)); + var questionVector = + response.getData().get(0).getEmbedding().stream() + .map(Double::floatValue) + .toList(); + if (ragOptions.getRetrievalMode() == RetrievalMode.vectors) { + setSearchOptionsForVector(ragOptions, questionVector, searchOptions); + } else { + searchText = question; + setSearchOptionsForHybrid(ragOptions, questionVector, searchOptions); + } + } else { + searchText = question; + setSearchOptions(ragOptions, searchOptions); + } + + SearchPagedIterable searchResults = + cognitiveSearchProxy.search(searchText, searchOptions, Context.NONE); + return buildSourcesFromSearchResults(ragOptions, searchResults); + } + + /** + * @param conversation + * @param ragOptions + * @return facts retrieved from the search index based on GPT optimized search keywords + * extracted from the chat history + */ + @Override + public List retrieveFromConversation( + ChatGPTConversation conversation, RAGOptions ragOptions) { + + // STEP 1: Generate an optimized keyword search query based on the chat history and the last + // question + var extractKeywordsChatTemplate = new ExtractKeywordsChatTemplate(conversation); + var chatCompletionsOptions = + ChatGPTUtils.buildDefaultChatCompletionsOptions( + extractKeywordsChatTemplate.getMessages()); + ChatCompletions chatCompletions = openAIProxy.getChatCompletions(chatCompletionsOptions); + + var searchKeywords = chatCompletions.getChoices().get(0).getMessage().getContent(); + LOGGER.info("Search Keywords extracted by Open AI [{}]", searchKeywords); + + // STEP 2: Retrieve relevant documents from the search index with the GPT optimized search + // keywords + return retrieveFromQuestion(searchKeywords, ragOptions); + } + + private List buildSourcesFromSearchResults( + RAGOptions options, SearchPagedIterable searchResults) { + List sources = new ArrayList<>(); + + searchResults + .iterator() + .forEachRemaining( + result -> { + var searchDocument = result.getDocument(SearchDocument.class); + + /* + If captions is enabled the content source is taken from the captions generated by the semantic ranker. + Captions are appended sequentially and separated by a dot. + */ + if (options.isSemanticCaptions()) { + StringBuilder sourcesContentBuffer = new StringBuilder(); + + result.getCaptions() + .forEach( + caption -> + sourcesContentBuffer + .append(caption.getText()) + .append(".")); + + sources.add( + new ContentSource( + (String) searchDocument.get("sourcepage"), + sourcesContentBuffer.toString())); + } else { + // If captions is disabled the content source is taken from the + // cognitive search index field "content" + sources.add( + new ContentSource( + (String) searchDocument.get("sourcepage"), + (String) searchDocument.get("content"))); + } + }); + + return sources; + } + + private void setSearchOptionsForHybrid( + RAGOptions ragOptions, List questionVector, SearchOptions searchOptions) { + setSearchOptions(ragOptions, searchOptions); + setSearchOptionsForVector(ragOptions, questionVector, searchOptions); + } + + private void setSearchOptionsForVector( + RAGOptions options, List questionVector, SearchOptions searchOptions) { + + Optional.ofNullable(options.getTop()) + .ifPresentOrElse(searchOptions::setTop, () -> searchOptions.setTop(3)); + + searchOptions.setVectors( + new SearchQueryVector() + .setValue(questionVector) + .setKNearestNeighborsCount(options.getTop()) + .setFields("embedding")); + } + + private void setSearchOptions(RAGOptions options, SearchOptions searchOptions) { + + Optional.ofNullable(options.getTop()) + .ifPresentOrElse(searchOptions::setTop, () -> searchOptions.setTop(3)); + Optional.ofNullable(options.getExcludeCategory()) + .ifPresentOrElse( + value -> + searchOptions.setFilter( + "category ne '%s'".formatted(value.replace("'", "''"))), + () -> searchOptions.setFilter(null)); + + Optional.ofNullable(options.isSemanticRanker()) + .ifPresent( + isSemanticRanker -> { + if (isSemanticRanker) { + searchOptions.setQueryType(QueryType.SEMANTIC); + searchOptions.setQueryLanguage(QueryLanguage.EN_US); + searchOptions.setSpeller(QuerySpellerType.LEXICON); + searchOptions.setSemanticConfigurationName("default"); + searchOptions.setQueryCaption(QueryCaptionType.EXTRACTIVE); + searchOptions.setQueryCaptionHighlightEnabled(false); + } + }); + } +} diff --git a/app/backend/src/main/java/com/microsoft/openai/samples/rag/retrieval/ExtractKeywordsChatTemplate.java b/app/backend/src/main/java/com/microsoft/openai/samples/rag/retrieval/ExtractKeywordsChatTemplate.java index c581f86..f7b3439 100644 --- a/app/backend/src/main/java/com/microsoft/openai/samples/rag/retrieval/ExtractKeywordsChatTemplate.java +++ b/app/backend/src/main/java/com/microsoft/openai/samples/rag/retrieval/ExtractKeywordsChatTemplate.java @@ -1,81 +1,77 @@ -package com.microsoft.openai.samples.rag.retrieval; - -import com.azure.ai.openai.models.ChatMessage; -import com.azure.ai.openai.models.ChatRole; -import com.microsoft.openai.samples.rag.approaches.ContentSource; -import com.microsoft.openai.samples.rag.common.ChatGPTConversation; -import com.microsoft.openai.samples.rag.common.ChatGPTUtils; - -import java.util.ArrayList; -import java.util.List; - -public class ExtractKeywordsChatTemplate { - - private final List conversationHistory = new ArrayList<>(); - - private String customPrompt = ""; - private Boolean replacePrompt = false; - - private static final String USER_CHAT_MESSAGE_TEMPLATE = """ - Generate a search query for the below conversation. - Do not include cited source filenames and document names e.g info.txt or doc.pdf in the search query terms. - Do not include any text inside [] or <<>> in the search query terms. - Do not enclose the search query in quotes or double quotes. - conversation: - %s - """ ; - - /** - * - * @param conversation conversation history - * @param sources domain specific sources to be used in the prompt - * @param customPrompt custom prompt to be injected in the existing promptTemplate or used to replace it - * @param replacePrompt if true, the customPrompt will replace the default promptTemplate, otherwise it will be appended - * to the default promptTemplate in the predefined section - */ - - private static final String GROUNDED_USER_QUESTION_TEMPLATE = """ - %s - Sources: - %s - """; - public ExtractKeywordsChatTemplate(ChatGPTConversation conversation) { - if (conversation == null || conversation.getMessages().isEmpty()) - throw new IllegalStateException("conversation cannot be null or empty"); - - String chatHistory = ChatGPTUtils.formatAsChatML(conversation.toOpenAIChatMessages()); - //Add system message - ChatMessage chatUserMessage = new ChatMessage(ChatRole.USER); - chatUserMessage.setContent(USER_CHAT_MESSAGE_TEMPLATE.formatted(chatHistory)); - - - this.conversationHistory.add(chatUserMessage); - - /** - //Add few shoot learning with chat - ChatMessage fewShotUser1Message = new ChatMessage(ChatRole.USER); - fewShotUser1Message.setContent("What are my health plans?"); - this.conversationHistory.add(fewShotUser1Message); - - ChatMessage fewShotAssistant1Message = new ChatMessage(ChatRole.ASSISTANT); - fewShotAssistant1Message.setContent("show available health plans"); - this.conversationHistory.add(fewShotAssistant1Message); - - ChatMessage fewShotUser2Message = new ChatMessage(ChatRole.USER); - fewShotUser2Message.setContent("does my plan cover cardio?"); - this.conversationHistory.add(fewShotUser2Message); - - ChatMessage fewShotAssistant2Message = new ChatMessage(ChatRole.ASSISTANT); - fewShotAssistant2Message.setContent("Health plan cardio coverage"); - this.conversationHistory.add(fewShotAssistant2Message); - **/ - //this.conversationHistory.addAll(conversation.toOpenAIChatMessages()); - } - - - public List getMessages() { - return this.conversationHistory; - } - - -} +// Copyright (c) Microsoft. All rights reserved. +package com.microsoft.openai.samples.rag.retrieval; + +import com.azure.ai.openai.models.ChatMessage; +import com.azure.ai.openai.models.ChatRole; +import com.microsoft.openai.samples.rag.common.ChatGPTConversation; +import com.microsoft.openai.samples.rag.common.ChatGPTUtils; +import java.util.ArrayList; +import java.util.List; + +public class ExtractKeywordsChatTemplate { + + private final List conversationHistory = new ArrayList<>(); + + private String customPrompt = ""; + private Boolean replacePrompt = false; + + private static final String USER_CHAT_MESSAGE_TEMPLATE = + """ + Generate a search query for the below conversation. + Do not include cited source filenames and document names e.g info.txt or doc.pdf in the search query terms. + Do not include any text inside [] or <<>> in the search query terms. + Do not enclose the search query in quotes or double quotes. + conversation: + %s + """; + + /** + * @param conversation conversation history + * @param sources domain specific sources to be used in the prompt + * @param customPrompt custom prompt to be injected in the existing promptTemplate or used to + * replace it + * @param replacePrompt if true, the customPrompt will replace the default promptTemplate, + * otherwise it will be appended to the default promptTemplate in the predefined section + */ + private static final String GROUNDED_USER_QUESTION_TEMPLATE = + """ + %s + Sources: + %s + """; + + public ExtractKeywordsChatTemplate(ChatGPTConversation conversation) { + if (conversation == null || conversation.getMessages().isEmpty()) + throw new IllegalStateException("conversation cannot be null or empty"); + + String chatHistory = ChatGPTUtils.formatAsChatML(conversation.toOpenAIChatMessages()); + // Add system message + ChatMessage chatUserMessage = new ChatMessage(ChatRole.USER); + chatUserMessage.setContent(USER_CHAT_MESSAGE_TEMPLATE.formatted(chatHistory)); + + this.conversationHistory.add(chatUserMessage); + + /** + * //Add few shoot learning with chat ChatMessage fewShotUser1Message = new + * ChatMessage(ChatRole.USER); fewShotUser1Message.setContent("What are my health plans?"); + * this.conversationHistory.add(fewShotUser1Message); + * + *

ChatMessage fewShotAssistant1Message = new ChatMessage(ChatRole.ASSISTANT); + * fewShotAssistant1Message.setContent("show available health plans"); + * this.conversationHistory.add(fewShotAssistant1Message); + * + *

ChatMessage fewShotUser2Message = new ChatMessage(ChatRole.USER); + * fewShotUser2Message.setContent("does my plan cover cardio?"); + * this.conversationHistory.add(fewShotUser2Message); + * + *

ChatMessage fewShotAssistant2Message = new ChatMessage(ChatRole.ASSISTANT); + * fewShotAssistant2Message.setContent("Health plan cardio coverage"); + * this.conversationHistory.add(fewShotAssistant2Message); + */ + // this.conversationHistory.addAll(conversation.toOpenAIChatMessages()); + } + + public List getMessages() { + return this.conversationHistory; + } +} diff --git a/app/backend/src/main/java/com/microsoft/openai/samples/rag/retrieval/FactsRetrieverProvider.java b/app/backend/src/main/java/com/microsoft/openai/samples/rag/retrieval/FactsRetrieverProvider.java index 440f6b4..6bd8f0c 100644 --- a/app/backend/src/main/java/com/microsoft/openai/samples/rag/retrieval/FactsRetrieverProvider.java +++ b/app/backend/src/main/java/com/microsoft/openai/samples/rag/retrieval/FactsRetrieverProvider.java @@ -1,33 +1,32 @@ -package com.microsoft.openai.samples.rag.retrieval; - -import com.microsoft.openai.samples.rag.approaches.RAGOptions; -import org.springframework.context.ApplicationContext; -import org.springframework.context.ApplicationContextAware; -import org.springframework.stereotype.Component; - -@Component -public class FactsRetrieverProvider implements ApplicationContextAware { - private ApplicationContext applicationContext; - - /** - * - * @param options rag options containing search types(Cognitive Semantic Search, Cognitive Vector Search, Cognitive Hybrid Search) - * Default is now cognitive search. - * @return retriever implementation - */ - public Retriever getFactsRetriever(RAGOptions options) { - //default to Cognitive Semantic Search for MVP. More useful in the future to support multiple retrieval systems (RedisSearch.Pinecone, etc) - switch (options.getRetrievalMode()){ - case vectors,hybrid,text: - return this.applicationContext.getBean(CognitiveSearchRetriever.class); - default: - return this.applicationContext.getBean(CognitiveSearchRetriever.class); - - - } - } - - public void setApplicationContext(ApplicationContext applicationContext) { - this.applicationContext = applicationContext; - } -} \ No newline at end of file +// Copyright (c) Microsoft. All rights reserved. +package com.microsoft.openai.samples.rag.retrieval; + +import com.microsoft.openai.samples.rag.approaches.RAGOptions; +import org.springframework.context.ApplicationContext; +import org.springframework.context.ApplicationContextAware; +import org.springframework.stereotype.Component; + +@Component +public class FactsRetrieverProvider implements ApplicationContextAware { + private ApplicationContext applicationContext; + + /** + * @param options rag options containing search types(Cognitive Semantic Search, Cognitive + * Vector Search, Cognitive Hybrid Search) Default is now cognitive search. + * @return retriever implementation + */ + public Retriever getFactsRetriever(RAGOptions options) { + // default to Cognitive Semantic Search for MVP. More useful in the future to support + // multiple retrieval systems (RedisSearch.Pinecone, etc) + switch (options.getRetrievalMode()) { + case vectors, hybrid, text: + return this.applicationContext.getBean(CognitiveSearchRetriever.class); + default: + return this.applicationContext.getBean(CognitiveSearchRetriever.class); + } + } + + public void setApplicationContext(ApplicationContext applicationContext) { + this.applicationContext = applicationContext; + } +} diff --git a/app/backend/src/main/java/com/microsoft/openai/samples/rag/retrieval/Retriever.java b/app/backend/src/main/java/com/microsoft/openai/samples/rag/retrieval/Retriever.java index 7bec0e1..3013fdc 100644 --- a/app/backend/src/main/java/com/microsoft/openai/samples/rag/retrieval/Retriever.java +++ b/app/backend/src/main/java/com/microsoft/openai/samples/rag/retrieval/Retriever.java @@ -1,14 +1,15 @@ -package com.microsoft.openai.samples.rag.retrieval; - -import com.microsoft.openai.samples.rag.approaches.ContentSource; -import com.microsoft.openai.samples.rag.approaches.RAGOptions; -import com.microsoft.openai.samples.rag.common.ChatGPTConversation; - -import java.util.List; - -public interface Retriever { - - List retrieveFromQuestion(String question, RAGOptions ragOptions); - - List retrieveFromConversation(ChatGPTConversation conversation, RAGOptions ragOptions); -} +// Copyright (c) Microsoft. All rights reserved. +package com.microsoft.openai.samples.rag.retrieval; + +import com.microsoft.openai.samples.rag.approaches.ContentSource; +import com.microsoft.openai.samples.rag.approaches.RAGOptions; +import com.microsoft.openai.samples.rag.common.ChatGPTConversation; +import java.util.List; + +public interface Retriever { + + List retrieveFromQuestion(String question, RAGOptions ragOptions); + + List retrieveFromConversation( + ChatGPTConversation conversation, RAGOptions ragOptions); +} diff --git a/app/backend/src/test/java/com/microsoft/openai/samples/rag/AskAPITest.java b/app/backend/src/test/java/com/microsoft/openai/samples/rag/AskAPITest.java index 897af08..dd2f557 100644 --- a/app/backend/src/test/java/com/microsoft/openai/samples/rag/AskAPITest.java +++ b/app/backend/src/test/java/com/microsoft/openai/samples/rag/AskAPITest.java @@ -1,5 +1,10 @@ +// Copyright (c) Microsoft. All rights reserved. package com.microsoft.openai.samples.rag; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.when; + import com.azure.ai.openai.models.Choice; import com.azure.ai.openai.models.Completions; import com.azure.ai.openai.models.CompletionsOptions; @@ -13,70 +18,66 @@ import com.microsoft.openai.samples.rag.proxy.OpenAIProxy; import com.microsoft.openai.samples.rag.test.utils.CognitiveSearchUnitTestUtils; import com.microsoft.openai.samples.rag.test.utils.OpenAIUnitTestUtils; +import java.net.URI; +import java.util.*; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.boot.test.context.SpringBootTest; import org.springframework.boot.test.mock.mockito.MockBean; import org.springframework.boot.test.web.client.TestRestTemplate; import org.springframework.test.context.ActiveProfiles; -import java.net.URI; -import java.util.*; - -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.eq; -import static org.mockito.Mockito.when; - /** * This class tests the Ask API showcasing how you can mock azure services using mockito. - * CognitiveSearch and OpenAI models are immutable from the client usage perspective, so in order to create when/then condition with mockito - * we used a reflection hack to make some model private constructor public. @see CognitiveSearchUnitTestUtils and @see OpenAIUnitTestUtils for more info. + * CognitiveSearch and OpenAI models are immutable from the client usage perspective, so in order to + * create when/then condition with mockito we used a reflection hack to make some model private + * constructor public. @see CognitiveSearchUnitTestUtils and @see OpenAIUnitTestUtils for more info. */ @ActiveProfiles("test") @SpringBootTest(webEnvironment = SpringBootTest.WebEnvironment.RANDOM_PORT) class AskAPITest { - @Autowired - private TestRestTemplate restTemplate; - @MockBean - private CognitiveSearchProxy cognitiveSearchProxyMock; - - @MockBean - private OpenAIProxy openAIProxyMock; - - /** after changing ask implementation from open ai completion to chatcompletion this test breaks. - * //TODO need to update mocks with ChatCompletions mocks instead of Completions mocks - @Test - void testExample() { - prepareMocks(); - - AskRequest askRequest = new AskRequest(); - askRequest.setQuestion("What does a Product Manager do?"); - askRequest.setApproach("rtr"); - - HttpEntity request = new HttpEntity<>(askRequest); - - ResponseEntity result = this.restTemplate.postForEntity(uri("/api/ask"), request, AskResponse.class); - - assertEquals(HttpStatus.OK, result.getStatusCode()); - assertNotNull(result.getBody()); - assertEquals("Product managers put items in roadmaps and backlogs", result.getBody().getAnswer()); - assertEquals(2, result.getBody().getDataPoints().size()); - assertEquals("cit1.pdf: This is a test document 1 for the unit test", result.getBody().getDataPoints().get(0)); - assertEquals("cit2.pdf: This is a test document 2 for the unit test", result.getBody().getDataPoints().get(1)); - } - */ - + @Autowired private TestRestTemplate restTemplate; + @MockBean private CognitiveSearchProxy cognitiveSearchProxyMock; + + @MockBean private OpenAIProxy openAIProxyMock; + + /** + * after changing ask implementation from open ai completion to chatcompletion this test breaks. + * //TODO need to update mocks with ChatCompletions mocks instead of Completions mocks @Test + * void testExample() { prepareMocks(); + * + *

AskRequest askRequest = new AskRequest(); askRequest.setQuestion("What does a Product + * Manager do?"); askRequest.setApproach("rtr"); + * + *

HttpEntity request = new HttpEntity<>(askRequest); + * + *

ResponseEntity result = this.restTemplate.postForEntity(uri("/api/ask"), + * request, AskResponse.class); + * + *

assertEquals(HttpStatus.OK, result.getStatusCode()); assertNotNull(result.getBody()); + * assertEquals("Product managers put items in roadmaps and backlogs", + * result.getBody().getAnswer()); assertEquals(2, result.getBody().getDataPoints().size()); + * assertEquals("cit1.pdf: This is a test document 1 for the unit test", + * result.getBody().getDataPoints().get(0)); assertEquals("cit2.pdf: This is a test document 2 + * for the unit test", result.getBody().getDataPoints().get(1)); } + */ private void prepareMocks() { SearchPagedIterable searchPagedIterable = buildSearchPagedIterableWithDocs(); - when(cognitiveSearchProxyMock.search(eq("What does a Product Manager do?"), any(SearchOptions.class), eq(Context.NONE))).thenReturn(searchPagedIterable); + when(cognitiveSearchProxyMock.search( + eq("What does a Product Manager do?"), + any(SearchOptions.class), + eq(Context.NONE))) + .thenReturn(searchPagedIterable); Completions mockedCompletions = buildCompletions(); - when(openAIProxyMock.getCompletions(any(CompletionsOptions.class))).thenReturn(mockedCompletions); + when(openAIProxyMock.getCompletions(any(CompletionsOptions.class))) + .thenReturn(mockedCompletions); } private Completions buildCompletions() { OpenAIUnitTestUtils utils = new OpenAIUnitTestUtils(); - Choice choice1 = utils.createChoice("Product managers put items in roadmaps and backlogs", 0); + Choice choice1 = + utils.createChoice("Product managers put items in roadmaps and backlogs", 0); List choices = List.of(choice1); @@ -91,7 +92,6 @@ private SearchPagedIterable buildSearchPagedIterableWithDocs() { cit1HasDoc.put("content", "This is a test document 1 for the unit test"); cit1HasDoc.put("sourcepage", "cit1.pdf"); - SearchDocument cit1Document = new SearchDocument(cit1HasDoc); SearchResult cit1SearchResult = new SearchResult(0.6); utils.setSearchDocument(cit1Document, cit1SearchResult); @@ -103,18 +103,21 @@ private SearchPagedIterable buildSearchPagedIterableWithDocs() { SearchResult cit2SearchResult = new SearchResult(0.6); utils.setSearchDocument(cit2Document, cit2SearchResult); - return new SearchPagedIterable(utils.getSearchPagedFlux(1, (inputInteger) -> List.of(cit1SearchResult, cit2SearchResult))); + return new SearchPagedIterable( + utils.getSearchPagedFlux( + 1, (inputInteger) -> List.of(cit1SearchResult, cit2SearchResult))); } - private URI uri(String path) { return restTemplate.getRestTemplate().getUriTemplateHandler().expand(path); } private CompletionsOptions buildCompletionsOptionsWithMockData() { - CompletionsOptions completionsOptions = new CompletionsOptions(new ArrayList<>(Arrays.asList(PROMPT_WITH_MOCKED_SOURCES))); + CompletionsOptions completionsOptions = + new CompletionsOptions(new ArrayList<>(Arrays.asList(PROMPT_WITH_MOCKED_SOURCES))); - // Due to a potential bug in using JVM 17 and java open SDK 1.0.0-beta.2, we need to provide default for all properties to avoid 404 bad Request on the server + // Due to a potential bug in using JVM 17 and java open SDK 1.0.0-beta.2, we need to provide + // default for all properties to avoid 404 bad Request on the server completionsOptions.setMaxTokens(1024); completionsOptions.setTemperature(0.3); completionsOptions.setStop(new ArrayList<>(Arrays.asList("\n"))); @@ -130,31 +133,31 @@ private CompletionsOptions buildCompletionsOptionsWithMockData() { return completionsOptions; } - private static final String PROMPT_WITH_MOCKED_SOURCES = """ + private static final String PROMPT_WITH_MOCKED_SOURCES = + """ You are an intelligent assistant helping Contoso Inc employees with their healthcare plan questions and employee handbook questions. Use 'you' to refer to the individual asking the questions even if they ask with 'I'. Answer the following question using only the data provided in the sources below. For tabular information return it as an html table. Do not return markdown format. Each source has a name followed by colon and the actual information, always include the source name for each fact you use in the response. If you cannot answer using the sources below, say you don't know. - + ### Question: 'What is the deductible for the employee plan for a visit to Overlake in Bellevue?' - + Sources: cit1.pdf: This is a test document 1 for the unit test cit2.pdf: This is a test document 2 for the unit test - + Answer: In-network deductibles are $500 for employee and $1000 for family [info1.txt] and Overlake is in-network for the employee plan [info2.pdf][info4.pdf]. - + ### Question:'What does a Product Manager do??' - + Sources: %s - + Answer: """; - } diff --git a/app/backend/src/test/java/com/microsoft/openai/samples/rag/ChatAPITest.java b/app/backend/src/test/java/com/microsoft/openai/samples/rag/ChatAPITest.java index 5bc09ea..fa15149 100644 --- a/app/backend/src/test/java/com/microsoft/openai/samples/rag/ChatAPITest.java +++ b/app/backend/src/test/java/com/microsoft/openai/samples/rag/ChatAPITest.java @@ -1,42 +1,40 @@ +// Copyright (c) Microsoft. All rights reserved. package com.microsoft.openai.samples.rag; +import java.net.URI; import org.junit.jupiter.api.Test; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.boot.test.context.SpringBootTest; import org.springframework.boot.test.web.client.TestRestTemplate; import org.springframework.test.context.ActiveProfiles; -import java.net.URI; - @ActiveProfiles("test") @SpringBootTest(webEnvironment = SpringBootTest.WebEnvironment.RANDOM_PORT) class ChatAPITest { - @Autowired - private TestRestTemplate restTemplate; - - @Test - void testExample() throws Exception { - //TODO This test is failing due to external services dependencies. - // Explore springmockserver to mock the external services response based on specific requests. Sevices to mock are: Azure token, Cognitive Search, OpenAI chat apis -/** ChatTurn chatTurn = new ChatTurn(); - chatTurn.setUserText("What does a Product Manager do?"); - List chatTurns = new ArrayList<>(); - chatTurns.add(chatTurn); - - ChatRequest chatRequest = new ChatRequest(); - chatRequest.setChatHistory(chatTurns); - chatRequest.setApproach("rrr"); - HttpEntity request = new HttpEntity<>(chatRequest); - - ResponseEntity result = this.restTemplate.postForEntity(uri("/api/chat"), chatRequest, ChatResponse.class); - - assertEquals(HttpStatus.OK, result.getStatusCode()); -**/ - } - - private URI uri(String path) { - return restTemplate.getRestTemplate().getUriTemplateHandler().expand(path); - } - -} \ No newline at end of file + @Autowired private TestRestTemplate restTemplate; + + @Test + void testExample() throws Exception { + // TODO This test is failing due to external services dependencies. + // Explore springmockserver to mock the external services response based on specific + // requests. Sevices to mock are: Azure token, Cognitive Search, OpenAI chat apis + /** + * ChatTurn chatTurn = new ChatTurn(); chatTurn.setUserText("What does a Product Manager + * do?"); List chatTurns = new ArrayList<>(); chatTurns.add(chatTurn); + * + *

ChatRequest chatRequest = new ChatRequest(); chatRequest.setChatHistory(chatTurns); + * chatRequest.setApproach("rrr"); HttpEntity request = new + * HttpEntity<>(chatRequest); + * + *

ResponseEntity result = + * this.restTemplate.postForEntity(uri("/api/chat"), chatRequest, ChatResponse.class); + * + *

assertEquals(HttpStatus.OK, result.getStatusCode()); + */ + } + + private URI uri(String path) { + return restTemplate.getRestTemplate().getUriTemplateHandler().expand(path); + } +} diff --git a/app/backend/src/test/java/com/microsoft/openai/samples/rag/approaches/RAGApproachFactorySpringBootImplTest.java b/app/backend/src/test/java/com/microsoft/openai/samples/rag/approaches/RAGApproachFactorySpringBootImplTest.java index 4ea2391..74e5e73 100644 --- a/app/backend/src/test/java/com/microsoft/openai/samples/rag/approaches/RAGApproachFactorySpringBootImplTest.java +++ b/app/backend/src/test/java/com/microsoft/openai/samples/rag/approaches/RAGApproachFactorySpringBootImplTest.java @@ -1,5 +1,9 @@ +// Copyright (c) Microsoft. All rights reserved. package com.microsoft.openai.samples.rag.approaches; +import static org.junit.jupiter.api.Assertions.assertInstanceOf; +import static org.junit.jupiter.api.Assertions.assertThrows; + import com.azure.ai.openai.OpenAIAsyncClient; import com.azure.core.credential.TokenCredential; import com.azure.search.documents.SearchAsyncClient; @@ -15,67 +19,62 @@ import org.springframework.boot.test.mock.mockito.MockBean; import org.springframework.test.context.ActiveProfiles; -import static org.junit.jupiter.api.Assertions.assertInstanceOf; -import static org.junit.jupiter.api.Assertions.assertThrows; - @ActiveProfiles("test") @SpringBootTest class RAGApproachFactorySpringBootImplTest { - @MockBean - private CognitiveSearchProxy cognitiveSearchProxyMock; + @MockBean private CognitiveSearchProxy cognitiveSearchProxyMock; - @MockBean - private TokenCredential tokenCredential; - - @MockBean - private SearchAsyncClient searchAsyncClient; - @MockBean - private OpenAIAsyncClient openAIAsyncClient; - - @Autowired - private RAGApproachFactorySpringBootImpl ragApproachFactory; + @MockBean private TokenCredential tokenCredential; + @MockBean private SearchAsyncClient searchAsyncClient; + @MockBean private OpenAIAsyncClient openAIAsyncClient; + @Autowired private RAGApproachFactorySpringBootImpl ragApproachFactory; @Test void testCreateApproachWithJavaPlain() { - RAGApproach approach = ragApproachFactory.createApproach("jos", RAGType.ASK,null); + RAGApproach approach = ragApproachFactory.createApproach("jos", RAGType.ASK, null); assertInstanceOf(PlainJavaAskApproach.class, approach); } @Test void testCreateApproachWithJavaSemanticKernelMemory() { - RAGApproach approach = ragApproachFactory.createApproach("jsk", RAGType.ASK,null); + RAGApproach approach = ragApproachFactory.createApproach("jsk", RAGType.ASK, null); assertInstanceOf(JavaSemanticKernelWithMemoryApproach.class, approach); } + @Test void testCreateApproachWithJavaSemanticKernelChain() { var ragOptions = new RAGOptions.Builder().semanticKernelMode("chains").build(); - RAGApproach approach = ragApproachFactory.createApproach("jskp", RAGType.ASK,ragOptions); + RAGApproach approach = ragApproachFactory.createApproach("jskp", RAGType.ASK, ragOptions); assertInstanceOf(JavaSemanticKernelChainsApproach.class, approach); } + @Test void testCreateApproachWithJavaSemanticKernelPlanner() { var ragOptions = new RAGOptions.Builder().semanticKernelMode("planner").build(); - RAGApproach approach = ragApproachFactory.createApproach("jskp", RAGType.ASK,ragOptions); + RAGApproach approach = ragApproachFactory.createApproach("jskp", RAGType.ASK, ragOptions); assertInstanceOf(JavaSemanticKernelPlannerApproach.class, approach); } @Test void testChatCreateApproachWithChat() { - RAGApproach approach = ragApproachFactory.createApproach("jos", RAGType.CHAT,null); + RAGApproach approach = ragApproachFactory.createApproach("jos", RAGType.CHAT, null); assertInstanceOf(PlainJavaChatApproach.class, approach); } @Test void testCreateApproachWithInvalidApproachName() { - assertThrows(IllegalArgumentException.class, () -> ragApproachFactory.createApproach("invalid", RAGType.ASK,null)); + assertThrows( + IllegalArgumentException.class, + () -> ragApproachFactory.createApproach("invalid", RAGType.ASK, null)); } @Test void testCreateApproachWithInvalidCombination() { - assertThrows(IllegalArgumentException.class, () -> ragApproachFactory.createApproach("rtr", RAGType.CHAT,null)); + assertThrows( + IllegalArgumentException.class, + () -> ragApproachFactory.createApproach("rtr", RAGType.CHAT, null)); } - -} \ No newline at end of file +} diff --git a/app/backend/src/test/java/com/microsoft/openai/samples/rag/test/config/ProxyMockConfiguration.java b/app/backend/src/test/java/com/microsoft/openai/samples/rag/test/config/ProxyMockConfiguration.java index 53ea5d7..49f91c6 100644 --- a/app/backend/src/test/java/com/microsoft/openai/samples/rag/test/config/ProxyMockConfiguration.java +++ b/app/backend/src/test/java/com/microsoft/openai/samples/rag/test/config/ProxyMockConfiguration.java @@ -1,41 +1,41 @@ -package com.microsoft.openai.samples.rag.test.config; - -import com.azure.ai.openai.OpenAIAsyncClient; -import com.microsoft.openai.samples.rag.proxy.BlobStorageProxy; -import com.microsoft.openai.samples.rag.proxy.CognitiveSearchProxy; -import com.microsoft.openai.samples.rag.proxy.OpenAIProxy; -import org.mockito.Mockito; -import org.springframework.context.annotation.Bean; -import org.springframework.context.annotation.Configuration; -import org.springframework.context.annotation.Primary; -import org.springframework.context.annotation.Profile; - -@Profile("test") -@Configuration -public class ProxyMockConfiguration { - - @Bean - @Primary - public CognitiveSearchProxy mockedCognitiveSearchProxy() { - return Mockito.mock(CognitiveSearchProxy.class); - } - - @Bean - @Primary - public OpenAIProxy mockedOpenAISearchProxy() { - return Mockito.mock(OpenAIProxy.class); - } - - @Bean - @Primary - public BlobStorageProxy mockedBlobStorageProxy() { - return Mockito.mock(BlobStorageProxy.class); - } - - @Bean - @Primary - public OpenAIAsyncClient mockedOpenAIAsynchClient() { - return Mockito.mock(OpenAIAsyncClient.class); - } - -} +// Copyright (c) Microsoft. All rights reserved. +package com.microsoft.openai.samples.rag.test.config; + +import com.azure.ai.openai.OpenAIAsyncClient; +import com.microsoft.openai.samples.rag.proxy.BlobStorageProxy; +import com.microsoft.openai.samples.rag.proxy.CognitiveSearchProxy; +import com.microsoft.openai.samples.rag.proxy.OpenAIProxy; +import org.mockito.Mockito; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.context.annotation.Primary; +import org.springframework.context.annotation.Profile; + +@Profile("test") +@Configuration +public class ProxyMockConfiguration { + + @Bean + @Primary + public CognitiveSearchProxy mockedCognitiveSearchProxy() { + return Mockito.mock(CognitiveSearchProxy.class); + } + + @Bean + @Primary + public OpenAIProxy mockedOpenAISearchProxy() { + return Mockito.mock(OpenAIProxy.class); + } + + @Bean + @Primary + public BlobStorageProxy mockedBlobStorageProxy() { + return Mockito.mock(BlobStorageProxy.class); + } + + @Bean + @Primary + public OpenAIAsyncClient mockedOpenAIAsynchClient() { + return Mockito.mock(OpenAIAsyncClient.class); + } +} diff --git a/app/backend/src/test/java/com/microsoft/openai/samples/rag/test/utils/CognitiveSearchUnitTestUtils.java b/app/backend/src/test/java/com/microsoft/openai/samples/rag/test/utils/CognitiveSearchUnitTestUtils.java index 519979a..242130c 100644 --- a/app/backend/src/test/java/com/microsoft/openai/samples/rag/test/utils/CognitiveSearchUnitTestUtils.java +++ b/app/backend/src/test/java/com/microsoft/openai/samples/rag/test/utils/CognitiveSearchUnitTestUtils.java @@ -1,3 +1,4 @@ +// Copyright (c) Microsoft. All rights reserved. package com.microsoft.openai.samples.rag.test.utils; import com.azure.core.http.HttpHeaderName; @@ -13,72 +14,108 @@ import com.azure.search.documents.models.SearchResult; import com.azure.search.documents.util.SearchPagedFlux; import com.azure.search.documents.util.SearchPagedResponse; -import reactor.core.publisher.Mono; - import java.lang.reflect.InvocationTargetException; import java.lang.reflect.Method; import java.util.List; import java.util.function.Function; import java.util.stream.Collectors; import java.util.stream.IntStream; +import reactor.core.publisher.Mono; public class CognitiveSearchUnitTestUtils { - private final HttpHeaders httpHeaders = new HttpHeaders() - .set(HttpHeaderName.fromString("header1"), "value1") - .set(HttpHeaderName.fromString("header2"), "value2"); + private final HttpHeaders httpHeaders = + new HttpHeaders() + .set(HttpHeaderName.fromString("header1"), "value1") + .set(HttpHeaderName.fromString("header2"), "value2"); private final HttpRequest httpRequest = new HttpRequest(HttpMethod.GET, "http://localhost"); private final String deserializedHeaders = "header1,value1,header2,value2"; private List> pagedResponses; - public PagedFlux getPagedFlux(int numberOfPages, Function> valueSupplier) { - List> pagedResponses = IntStream.range(0, numberOfPages) - .boxed() - .map(i -> - createPagedResponse(httpRequest, httpHeaders, deserializedHeaders, numberOfPages, valueSupplier, i)) - .collect(Collectors.toList()); - - return new PagedFlux<>(() -> pagedResponses.isEmpty() ? Mono.empty() : Mono.just(pagedResponses.get(0)), - continuationToken -> getNextPage(continuationToken, pagedResponses)); + public PagedFlux getPagedFlux( + int numberOfPages, Function> valueSupplier) { + List> pagedResponses = + IntStream.range(0, numberOfPages) + .boxed() + .map( + i -> + createPagedResponse( + httpRequest, + httpHeaders, + deserializedHeaders, + numberOfPages, + valueSupplier, + i)) + .collect(Collectors.toList()); + + return new PagedFlux<>( + () -> pagedResponses.isEmpty() ? Mono.empty() : Mono.just(pagedResponses.get(0)), + continuationToken -> getNextPage(continuationToken, pagedResponses)); } - public SearchPagedFlux getSearchPagedFlux(int numberOfPages, Function> valueSupplier) { - List searchPagedResponses = IntStream.range(0, numberOfPages) - .boxed() - .map(i -> - createSearchPagedResponse(httpRequest, httpHeaders, deserializedHeaders, numberOfPages, valueSupplier, i)) - .collect(Collectors.toList()); - - return new SearchPagedFlux(() -> searchPagedResponses.isEmpty() ? Mono.empty() : Mono.just(searchPagedResponses.get(0)), - continuationToken -> getNextSearchPage(continuationToken, searchPagedResponses)); + public SearchPagedFlux getSearchPagedFlux( + int numberOfPages, Function> valueSupplier) { + List searchPagedResponses = + IntStream.range(0, numberOfPages) + .boxed() + .map( + i -> + createSearchPagedResponse( + httpRequest, + httpHeaders, + deserializedHeaders, + numberOfPages, + valueSupplier, + i)) + .collect(Collectors.toList()); + + return new SearchPagedFlux( + () -> + searchPagedResponses.isEmpty() + ? Mono.empty() + : Mono.just(searchPagedResponses.get(0)), + continuationToken -> getNextSearchPage(continuationToken, searchPagedResponses)); } - private PagedResponseBase createPagedResponse(HttpRequest httpRequest, HttpHeaders headers, - String deserializedHeaders, - int numberOfPages, - Function> valueSupplier, - int i) { - return new PagedResponseBase<>(httpRequest, 200, headers, valueSupplier.apply(i), - (i < numberOfPages - 1) ? String.valueOf(i + 1) : null, - deserializedHeaders); + private PagedResponseBase createPagedResponse( + HttpRequest httpRequest, + HttpHeaders headers, + String deserializedHeaders, + int numberOfPages, + Function> valueSupplier, + int i) { + return new PagedResponseBase<>( + httpRequest, + 200, + headers, + valueSupplier.apply(i), + (i < numberOfPages - 1) ? String.valueOf(i + 1) : null, + deserializedHeaders); } - private SearchPagedResponse createSearchPagedResponse(HttpRequest httpRequest, HttpHeaders headers, - String deserializedHeaders, - int numberOfPages, - Function> valueSupplier, - int i) { + private SearchPagedResponse createSearchPagedResponse( + HttpRequest httpRequest, + HttpHeaders headers, + String deserializedHeaders, + int numberOfPages, + Function> valueSupplier, + int i) { String continuationToken = (i < numberOfPages - 1) ? String.valueOf(i + 1) : null; - PagedResponseBase pagedResponseBase = new PagedResponseBase<>(httpRequest, 200, headers, valueSupplier.apply(i), - continuationToken, - deserializedHeaders); + PagedResponseBase pagedResponseBase = + new PagedResponseBase<>( + httpRequest, + 200, + headers, + valueSupplier.apply(i), + continuationToken, + deserializedHeaders); return new SearchPagedResponse(pagedResponseBase, continuationToken, null, 0L, 0.0); } - private Mono> getNextPage(String continuationToken, - List> pagedResponses) { + private Mono> getNextPage( + String continuationToken, List> pagedResponses) { if (continuationToken == null || continuationToken.isEmpty()) { return Mono.empty(); @@ -92,8 +129,8 @@ private Mono> getNextPage(String continuationToken, return Mono.just(pagedResponses.get(parsedToken)); } - private Mono getNextSearchPage(String continuationToken, - List searchPagedResponses) { + private Mono getNextSearchPage( + String continuationToken, List searchPagedResponses) { if (continuationToken == null || continuationToken.isEmpty()) { return Mono.empty(); @@ -113,8 +150,11 @@ public void setSearchDocument(SearchDocument searchDocument, SearchResult search Method setJsonSerializerMethod; DefaultJsonSerializer defaultJsonSerializer = new DefaultJsonSerializer(); try { - setadditionalPropertiesMethod = SearchResult.class.getDeclaredMethod("setAdditionalProperties", SearchDocument.class); - setJsonSerializerMethod = SearchResult.class.getDeclaredMethod("setJsonSerializer", JsonSerializer.class); + setadditionalPropertiesMethod = + SearchResult.class.getDeclaredMethod( + "setAdditionalProperties", SearchDocument.class); + setJsonSerializerMethod = + SearchResult.class.getDeclaredMethod("setJsonSerializer", JsonSerializer.class); } catch (NoSuchMethodException e) { throw new RuntimeException("Unable to find method in SearchResult class", e); } @@ -126,7 +166,5 @@ public void setSearchDocument(SearchDocument searchDocument, SearchResult search } catch (IllegalAccessException | InvocationTargetException e) { throw new RuntimeException(e); } - } - } diff --git a/app/backend/src/test/java/com/microsoft/openai/samples/rag/test/utils/OpenAIUnitTestUtils.java b/app/backend/src/test/java/com/microsoft/openai/samples/rag/test/utils/OpenAIUnitTestUtils.java index e203381..5de7cba 100644 --- a/app/backend/src/test/java/com/microsoft/openai/samples/rag/test/utils/OpenAIUnitTestUtils.java +++ b/app/backend/src/test/java/com/microsoft/openai/samples/rag/test/utils/OpenAIUnitTestUtils.java @@ -1,7 +1,7 @@ +// Copyright (c) Microsoft. All rights reserved. package com.microsoft.openai.samples.rag.test.utils; import com.azure.ai.openai.models.*; - import java.lang.reflect.Constructor; import java.time.OffsetDateTime; import java.util.List; @@ -11,7 +11,12 @@ public class OpenAIUnitTestUtils { public Choice createChoice(String text, int index) { Constructor pcc; try { - pcc = Choice.class.getDeclaredConstructor(String.class, int.class, CompletionsLogProbabilityModel.class, CompletionsFinishReason.class); + pcc = + Choice.class.getDeclaredConstructor( + String.class, + int.class, + CompletionsLogProbabilityModel.class, + CompletionsFinishReason.class); } catch (NoSuchMethodException e) { throw new RuntimeException("No constructor found for Choice.class", e); } @@ -26,7 +31,8 @@ public Choice createChoice(String text, int index) { return choice; } - public CompletionsUsage createCompletionUsage(int completionTokens, int promptTokens, int totalTokens) { + public CompletionsUsage createCompletionUsage( + int completionTokens, int promptTokens, int totalTokens) { Constructor pcc; try { pcc = CompletionsUsage.class.getDeclaredConstructor(int.class, int.class, int.class); @@ -51,7 +57,9 @@ public Completions createCompletions(List choices) { public Completions createCompletions(List choices, CompletionsUsage completionsUsage) { Constructor pcc; try { - pcc = Completions.class.getDeclaredConstructor(String.class, OffsetDateTime.class, List.class, CompletionsUsage.class); + pcc = + Completions.class.getDeclaredConstructor( + String.class, OffsetDateTime.class, List.class, CompletionsUsage.class); } catch (NoSuchMethodException e) { throw new RuntimeException("No constructor found for Completions.class", e); } @@ -65,5 +73,4 @@ public Completions createCompletions(List choices, CompletionsUsage comp return choice; } - }