Skip to content

Commit

Permalink
Merge in latest UI from python and update models to match
Browse files Browse the repository at this point in the history
  • Loading branch information
johnoliver committed Oct 12, 2023
1 parent c6936aa commit 949717a
Show file tree
Hide file tree
Showing 71 changed files with 4,779 additions and 1,887 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ public Builder top(Integer top) {
return this;
}


public RAGOptions build() {
RAGOptions ragOptions = new RAGOptions();
ragOptions.retrievalMode = this.retrievalMode;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package com.microsoft.openai.samples.rag.ask.approaches.semantickernel;

import com.azure.ai.openai.OpenAIAsyncClient;
import com.microsoft.openai.samples.rag.approaches.ContentSource;
import com.microsoft.openai.samples.rag.approaches.RAGApproach;
import com.microsoft.openai.samples.rag.approaches.RAGOptions;
import com.microsoft.openai.samples.rag.approaches.RAGResponse;
Expand All @@ -9,15 +10,16 @@
import com.microsoft.semantickernel.Kernel;
import com.microsoft.semantickernel.SKBuilders;
import com.microsoft.semantickernel.orchestration.SKContext;
import com.microsoft.semantickernel.planner.sequentialplanner.SequentialPlanner;
import com.microsoft.semantickernel.planner.sequentialplanner.SequentialPlannerRequestSettings;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Component;

import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;

/**
* Use Java Semantic Kernel framework with semantic and native functions chaining. It uses an imperative style for AI orchestration through semantic kernel functions chaining.
Expand Down Expand Up @@ -59,6 +61,8 @@ public RAGResponse run(String question, RAGOptions options) {
question,
semanticKernel.getSkill("InformationFinder").getFunction("Search", null)).block();

var sources = formSourcesList(searchContext.getResult());

var answerVariables = SKBuilders.variables()
.withVariable("sources", searchContext.getResult())
.withVariable("input", question)
Expand All @@ -70,12 +74,33 @@ public RAGResponse run(String question, RAGOptions options) {
return new RAGResponse.Builder()
.prompt("Prompt is managed by Semantic Kernel")
.answer(answerExecutionContext.getResult())
.sources(sources)
.sourcesAsText(searchContext.getResult())
.question(question)
.build();

}

private List<ContentSource> formSourcesList(String result) {
if (result == null) {
return Collections.emptyList();
}
return Arrays.stream(result
.split("\n"))
.map(source -> {
String[] split = source.split(":", 2);
if (split.length >= 2) {
var sourceName = split[0].trim();
var sourceContent = split[1].trim();
return new ContentSource(sourceName, sourceContent);
} else {
return null;
}
})
.filter(Objects::nonNull)
.collect(Collectors.toList());
}

private Kernel buildSemanticKernel( RAGOptions options) {
Kernel kernel = SKBuilders.kernel()
.withDefaultAIService(SKBuilders.chatCompletion()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import com.azure.core.credential.TokenCredential;
import com.azure.search.documents.SearchAsyncClient;
import com.azure.search.documents.SearchDocument;
import com.microsoft.openai.samples.rag.approaches.ContentSource;
import com.microsoft.openai.samples.rag.approaches.RAGApproach;
import com.microsoft.openai.samples.rag.approaches.RAGOptions;
import com.microsoft.openai.samples.rag.approaches.RAGResponse;
Expand All @@ -22,6 +23,7 @@

import java.util.List;
import java.util.function.Function;
import java.util.stream.Collectors;

/**
* Accomplish the same task as in the PlainJavaAskApproach approach but using Semantic Kernel framework:
Expand Down Expand Up @@ -78,6 +80,7 @@ public RAGResponse run(String question, RAGOptions options) {
LOGGER.info("Total {} sources found in cognitive vector store for search query[{}]", memoryResult.size(), question);

String sources = buildSourcesText(memoryResult);
List<ContentSource> sourcesList = buildSources(memoryResult);

SKContext skcontext = SKBuilders.context().build()
.setVariable("sources", sources)
Expand All @@ -90,12 +93,25 @@ public RAGResponse run(String question, RAGOptions options) {
//.prompt(plan.toPlanString())
.prompt("placeholders for prompt")
.answer(result.block().getResult())
.sources(sourcesList)
.sourcesAsText(sources)
.question(question)
.build();

}

private List<ContentSource> buildSources(List<MemoryQueryResult> memoryResult) {
return memoryResult
.stream()
.map(result -> {
return new ContentSource(
result.getMetadata().getId(),
result.getMetadata().getText()
);
})
.collect(Collectors.toList());
}

private String buildSourcesText(List<MemoryQueryResult> memoryResult) {
StringBuilder sourcesContentBuffer = new StringBuilder();
memoryResult.stream().forEach(memory -> {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,16 @@
package com.microsoft.openai.samples.rag.ask.controller;

import com.microsoft.openai.samples.rag.approaches.*;
import com.microsoft.openai.samples.rag.controller.Overrides;
import com.microsoft.openai.samples.rag.approaches.RAGApproach;
import com.microsoft.openai.samples.rag.approaches.RAGApproachFactory;
import com.microsoft.openai.samples.rag.approaches.RAGOptions;
import com.microsoft.openai.samples.rag.approaches.RAGResponse;
import com.microsoft.openai.samples.rag.approaches.RAGType;
import com.microsoft.openai.samples.rag.chat.controller.ChatAppRequest;
import com.microsoft.openai.samples.rag.chat.controller.ChatResponse;
import com.microsoft.openai.samples.rag.chat.controller.ResponseChoice;
import com.microsoft.openai.samples.rag.chat.controller.ResponseContext;
import com.microsoft.openai.samples.rag.chat.controller.ResponseMessage;
import com.microsoft.openai.samples.rag.common.ChatGPTMessage;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.http.HttpStatus;
Expand All @@ -11,7 +20,7 @@
import org.springframework.web.bind.annotation.RequestBody;
import org.springframework.web.bind.annotation.RestController;

import java.util.Arrays;
import java.util.Collections;
import java.util.List;

@RestController
Expand All @@ -25,59 +34,60 @@ public class AskController {
}

@PostMapping("/api/ask")
public ResponseEntity<AskResponse> openAIAsk(@RequestBody AskRequest askRequest) {
LOGGER.info("Received request for ask api with question [{}] and approach[{}]", askRequest.getQuestion(), askRequest.getApproach());
public ResponseEntity<ChatResponse> openAIAsk(@RequestBody ChatAppRequest askRequest) {
String question = askRequest.messages().get(askRequest.messages().size() - 1).content();
LOGGER.info("Received request for ask api with question [{}] and approach[{}]", question, askRequest.approach());

if (!StringUtils.hasText(askRequest.getApproach())) {
if (!StringUtils.hasText(askRequest.approach())) {
LOGGER.warn("approach cannot be null in ASK request");
return ResponseEntity.status(HttpStatus.BAD_REQUEST).body(null);
}

if (!StringUtils.hasText(askRequest.getQuestion())) {
if (!StringUtils.hasText(question)) {
LOGGER.warn("question cannot be null in ASK request");
return ResponseEntity.status(HttpStatus.BAD_REQUEST).body(null);
}

var ragOptions = new RAGOptions.Builder()
.retrievialMode(askRequest.getOverrides().getRetrievalMode())
.semanticKernelMode(askRequest.getOverrides().getSemantickKernelMode())
.semanticRanker(askRequest.getOverrides().isSemanticRanker())
.semanticCaptions(askRequest.getOverrides().isSemanticCaptions())
.excludeCategory(askRequest.getOverrides().getExcludeCategory())
.promptTemplate(askRequest.getOverrides().getPromptTemplate())
.top(askRequest.getOverrides().getTop())
.retrievialMode(askRequest.context().overrides().retrieval_mode().name())
.semanticKernelMode(askRequest.context().overrides().semantic_kernel_mode())
.semanticRanker(askRequest.context().overrides().semantic_ranker())
.semanticCaptions(askRequest.context().overrides().semantic_captions())
.excludeCategory(askRequest.context().overrides().exclude_category())
.promptTemplate(askRequest.context().overrides().prompt_template())
.top(askRequest.context().overrides().top())
.build();

RAGApproach<String, RAGResponse> ragApproach = ragApproachFactory.createApproach(askRequest.getApproach(), RAGType.ASK, ragOptions);
RAGApproach<String, RAGResponse> ragApproach = ragApproachFactory.createApproach(askRequest.approach(), RAGType.ASK, ragOptions);

//set empty overrides if not provided
if (askRequest.getOverrides() == null) {
askRequest.setOverrides(new Overrides());
}



return ResponseEntity.ok(buildAskResponse(ragApproach.run(askRequest.getQuestion(), ragOptions)));
return ResponseEntity.ok(buildChatResponse(ragApproach.run(question, ragOptions)));
}

private AskResponse buildAskResponse(RAGResponse ragResponse) {
var askResponse = new AskResponse();
private ChatResponse buildChatResponse(RAGResponse ragResponse) {
List<String> dataPoints = Collections.emptyList();

askResponse.setAnswer(ragResponse.getAnswer());
List<String> dataPoints;
if (ragResponse.getSourcesAsText() != null && !ragResponse.getSourcesAsText().isEmpty()) {
dataPoints = Arrays.asList(ragResponse.getSourcesAsText().split("\n"));
} else {
if (ragResponse.getSources() != null) {
dataPoints = ragResponse.getSources().stream()
.map(source -> source.getSourceName() + ": " + source.getSourceContent())
.toList();
.map(source -> source.getSourceName() + ": " + source.getSourceContent())
.toList();
}

askResponse.setDataPoints(dataPoints);

askResponse.setThoughts("Question:<br>" + ragResponse.getQuestion() + "<br><br>Prompt:<br>" + ragResponse.getPrompt().replace("\n", "<br>"));

return askResponse;
String thoughts = "Question:<br>" + ragResponse.getQuestion() + "<br><br>Prompt:<br>" + ragResponse.getPrompt().replace("\n", "<br>");

return new ChatResponse(
List.of(
new ResponseChoice(
0,
new ResponseMessage(
ragResponse.getAnswer(),
ChatGPTMessage.ChatRole.ASSISTANT.toString()
),
new ResponseContext(
thoughts,
dataPoints
)
)
)
);
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
package com.microsoft.openai.samples.rag.chat.controller;

import java.util.List;

public record ChatAppRequest(
List<ResponseMessage> messages,
ChatAppRequestContext context,
boolean stream,
String approach
) {
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
package com.microsoft.openai.samples.rag.chat.controller;

public record ChatAppRequestContext(ChatAppRequestOverrides overrides) {
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
package com.microsoft.openai.samples.rag.chat.controller;

import com.microsoft.openai.samples.rag.approaches.RetrievalMode;

public record ChatAppRequestOverrides(
RetrievalMode retrieval_mode,
boolean semantic_ranker,
boolean semantic_captions,
String exclude_category,
int top,
float temperature,
String prompt_template,
String prompt_template_prefix,
String prompt_template_suffix,
boolean suggest_followup_questions,
boolean use_oid_security_filter,
boolean use_groups_security_filter,
String semantic_kernel_mode
) {
}
Loading

0 comments on commit 949717a

Please sign in to comment.