
Commit

Multi LLM prompt studio fixes (#515)
* multi llm prompt studio fixes

* code refactor

---------

Co-authored-by: Rahul Johny <[email protected]>
jagadeeswaran-zipstack and johnyrahul authored Jul 24, 2024
1 parent 3f01d13 commit 13a4edc
Showing 7 changed files with 133 additions and 51 deletions.
@@ -2,6 +2,7 @@
import logging
from typing import Any, Optional

from django.core.exceptions import ObjectDoesNotExist
from prompt_studio.prompt_profile_manager.models import ProfileManager
from prompt_studio.prompt_studio.models import ToolStudioPrompt
from prompt_studio.prompt_studio_core.exceptions import (
@@ -144,3 +145,39 @@ def fetch_default_llm_profile(tool: CustomTool) -> ProfileManager:
return ProfileManager.get_default_llm_profile(tool=tool)
except DefaultProfileError:
raise DefaultProfileError("Default ProfileManager does not exist.")

@staticmethod
def fetch_default_response(
tool_studio_prompts: list[ToolStudioPrompt], document_manager_id: str
) -> dict[str, Any]:
# Initialize the result dictionary
result: dict[str, Any] = {}
# Iterate over ToolStudioPrompt records
for tool_prompt in tool_studio_prompts:
if tool_prompt.prompt_type == PSOMKeys.NOTES:
continue
prompt_id = str(tool_prompt.prompt_id)
profile_manager_id = tool_prompt.profile_manager_id

# If profile_manager is not set, skip this record
if not profile_manager_id:
result[tool_prompt.prompt_key] = ""
continue

try:
queryset = PromptStudioOutputManager.objects.filter(
prompt_id=prompt_id,
profile_manager=profile_manager_id,
is_single_pass_extract=False,
document_manager_id=document_manager_id,
)

if not queryset.exists():
result[tool_prompt.prompt_key] = ""
continue

for output in queryset:
result[tool_prompt.prompt_key] = output.output
except ObjectDoesNotExist:
result[tool_prompt.prompt_key] = ""
return result
41 changes: 11 additions & 30 deletions backend/prompt_studio/prompt_studio_output_manager/views.py
@@ -9,6 +9,9 @@
PromptOutputManagerErrorMessage,
PromptStudioOutputManagerKeys,
)
from prompt_studio.prompt_studio_output_manager.output_manager_helper import (
OutputManagerHelper,
)
from prompt_studio.prompt_studio_output_manager.serializers import (
PromptStudioOutputSerializer,
)
@@ -68,38 +71,16 @@ def get_output_for_tool_default(self, request: HttpRequest) -> Response:

try:
# Fetch ToolStudioPrompt records based on tool_id
tool_studio_prompts = ToolStudioPrompt.objects.filter(tool_id=tool_id)
tool_studio_prompts = ToolStudioPrompt.objects.filter(
tool_id=tool_id
).order_by("sequence_number")
except ObjectDoesNotExist:
raise APIException(detail=tool_not_found, code=400)

# Initialize the result dictionary
result: dict[str, Any] = {}

# Iterate over ToolStudioPrompt records
for tool_prompt in tool_studio_prompts:
prompt_id = str(tool_prompt.prompt_id)
profile_manager_id = str(tool_prompt.profile_manager.profile_id)

# If profile_manager is not set, skip this record
if not profile_manager_id:
result[tool_prompt.prompt_key] = ""
continue

try:
queryset = PromptStudioOutputManager.objects.filter(
prompt_id=prompt_id,
profile_manager=profile_manager_id,
is_single_pass_extract=False,
document_manager_id=document_manager_id,
)

if not queryset.exists():
result[tool_prompt.prompt_key] = ""
continue

for output in queryset:
result[tool_prompt.prompt_key] = output.output
except ObjectDoesNotExist:
result[tool_prompt.prompt_key] = ""
# Invoke helper method to frame and fetch default response.
result: dict[str, Any] = OutputManagerHelper.fetch_default_response(
tool_studio_prompts=tool_studio_prompts,
document_manager_id=document_manager_id,
)

return Response(result, status=status.HTTP_200_OK)
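For context, a rough sketch of the payload shape this endpoint returns after the change: the helper builds one entry per non-notes prompt, keyed by prompt_key, and falls back to an empty string when no output is stored for the prompt's profile and the requested document. The prompt keys and values below are made-up placeholders, not fields from this repository.

// Hypothetical response body from get_output_for_tool_default; "invoice_number"
// and "total_amount" stand in for real prompt keys defined in a tool.
const exampleDefaultResponse = {
  invoice_number: '"INV-001"', // stored output for the prompt's default profile
  total_amount: "", // no stored output yet for this prompt/document
};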
@@ -68,6 +68,9 @@ function CombinedOutput({ docId, setFilledFields }) {
if (!docId || isSinglePassExtractLoading) {
return;
}
if (singlePassExtractMode && activeKey === "0") {
setActiveKey("1");
}

let filledFields = 0;
setIsOutputLoading(true);
@@ -27,15 +27,17 @@ function JsonView({
)}
{adapterData.map((adapter, index) => (
<TabPane
tab={<span>{adapter.llm_model}</span>}
tab={<span>{adapter?.llm_model || adapter?.profile_name}</span>}
key={(index + 1)?.toString()}
/>
))}
</Tabs>
<div className="combined-op-segment"></div>
</div>
<div className="combined-op-divider" />
<ProfileInfoBar profileId={selectedProfile} profiles={llmProfiles} />
{activeKey !== "0" && (
<ProfileInfoBar profileId={selectedProfile} profiles={llmProfiles} />
)}
<div className="combined-op-body code-snippet">
{combinedOutput && (
<pre className="line-numbers width-100">
@@ -14,6 +14,7 @@ import { useAxiosPrivate } from "../../../hooks/useAxiosPrivate";
import "./OutputForDocModal.css";
import {
displayPromptResult,
getDocIdFromKey,
getLLMModelNamesForProfiles,
} from "../../../helpers/GetStaticData";
import { SpinnerLoader } from "../../widgets/spinner-loader/SpinnerLoader";
@@ -117,13 +118,63 @@ function OutputForDocModal({

const updatePromptOutput = (data) => {
setPromptOutputs((prev) => {
// If data is provided, use it; otherwise, create a copy of the previous state
const updatedPromptOutput = data || [...prev];
const updatedPromptOutput = getUpdatedPromptOutput(data, prev);
const keys = Object.keys(docOutputs);

keys.forEach((key) => {
const docId = getDocIdFromKey(key);
updatePromptOutputInstance(updatedPromptOutput, docId, key);
});

return updatedPromptOutput;
});
};

const getUpdatedPromptOutput = (data, prev) => {
return data || [...prev];
};

const updatePromptOutputInstance = (updatedPromptOutput, docId, key) => {
const index = findPromptOutputIndex(updatedPromptOutput, docId);
const promptOutputInstance = createOrUpdatePromptOutputInstance(
updatedPromptOutput,
docId,
key,
index
);

if (index > -1) {
updatedPromptOutput[index] = promptOutputInstance;
} else {
updatedPromptOutput.push(promptOutputInstance);
}
};

const findPromptOutputIndex = (updatedPromptOutput, docId) => {
return updatedPromptOutput.findIndex(
(promptOutput) => promptOutput?.document_manager === docId
);
};

const createOrUpdatePromptOutputInstance = (
updatedPromptOutput,
docId,
key,
index
) => {
let promptOutputInstance = {};

if (index > -1) {
promptOutputInstance = { ...updatedPromptOutput[index] };
}

promptOutputInstance["document_manager"] = docId;
promptOutputInstance["output"] = docOutputs[key]?.output;
promptOutputInstance["isLoading"] = docOutputs[key]?.isLoading || false;

return promptOutputInstance;
};

const getAdapterInfo = () => {
axiosPrivate
.get(`/api/v1/unstract/${sessionDetails.orgId}/adapter/?adapter_type=LLM`)
@@ -172,7 +223,7 @@ function OutputForDocModal({
const output = data.find(
(outputValue) => outputValue?.document_manager === item?.document_id
);

const key = `${output?.prompt_id}__${output?.document_manager}__${output?.profile_manager}`;
let status = outputStatus.fail;
let message = displayPromptResult(output?.output, true);

@@ -188,6 +239,7 @@
status = outputStatus.yet_to_process;
message = "Yet to process";
}
const isLoading = docOutputs.find((obj) => obj?.key === key)?.isLoading;

const result = {
key: item?.document_id,
@@ -205,7 +257,7 @@
),
value: (
<>
{output?.isLoading ? (
{isLoading ? (
<SpinnerLoader align="default" />
) : (
<Typography.Text>
@@ -233,7 +285,7 @@

const handleTabChange = (key) => {
if (key === "0") {
setSelectedProfile(profileManagerId);
setSelectedProfile(defaultLlmProfile);
} else {
setSelectedProfile(adapterData[key - 1]?.profile_id);
}
@@ -281,7 +333,7 @@
<TabPane tab={<span>Default</span>} key={"0"}></TabPane>
{adapterData?.map((adapter, index) => (
<TabPane
tab={<span>{adapter?.llm_model}</span>}
tab={<span>{adapter?.llm_model || adapter?.profile_name}</span>}
key={(index + 1)?.toString()}
></TabPane>
))}
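A minimal, self-contained sketch of the merge the new updatePromptOutput helpers perform, using hypothetical ids and outputs: each docOutputs entry (keyed by prompt, document and profile) is folded into the promptOutputs array, replacing the entry with the same document_manager or appending a new one.

// Hypothetical data; mirrors the logic of updatePromptOutputInstance and
// createOrUpdatePromptOutputInstance above.
const docOutputs = {
  "prompt-1__doc-1__profile-1": { output: "result A", isLoading: false },
};
const promptOutputs = [{ document_manager: "doc-2", output: "result B" }];

Object.keys(docOutputs).forEach((key) => {
  const docId = key.split("__")[1]; // document id is the middle segment of the key
  const index = promptOutputs.findIndex((po) => po?.document_manager === docId);
  const instance = index > -1 ? { ...promptOutputs[index] } : {};
  instance.document_manager = docId;
  instance.output = docOutputs[key]?.output;
  instance.isLoading = docOutputs[key]?.isLoading || false;
  if (index > -1) {
    promptOutputs[index] = instance;
  } else {
    promptOutputs.push(instance);
  }
});
// promptOutputs now contains the original "doc-2" entry plus a new "doc-1" entry.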
20 changes: 7 additions & 13 deletions frontend/src/components/custom-tools/prompt-card/PromptCard.jsx
@@ -292,7 +292,6 @@ function PromptCard({
return;
}
setIsCoverageLoading(true);
setCoverage(0);
setCoverageTotal(0);
resetInfoMsgs();

@@ -339,9 +338,6 @@
.then((res) => {
const data = res?.data?.output;
const value = data[promptDetails?.prompt_key];
if (value || value === 0) {
setCoverage((prev) => prev + 1);
}
handleDocOutputs(
docId,
promptDetails?.prompt_id,
@@ -387,13 +383,7 @@
.then((res) => {
const data = res?.data?.output;
const value = data[promptDetails?.prompt_key];
if (value || value === 0) {
updateDocCoverage(
promptDetails?.prompt_id,
profileManagerId,
docId
);
}
updateDocCoverage(promptDetails?.prompt_id, profileManagerId, docId);
handleDocOutputs(
docId,
promptDetails?.prompt_id,
@@ -737,9 +727,13 @@

if (singlePassExtractMode) {
const tokenUsageId = `single_pass__${defaultLlmProfile}__${selectedDoc?.document_id}`;
const usage = data?.find((item) => item?.run_id !== undefined);
const usage = data?.find(
(item) =>
item?.profile_manager === defaultLlmProfile &&
item?.document_manager === selectedDoc?.document_id
);

if (!tokenUsage[tokenUsageId] && usage) {
if (usage) {
setTokenUsage(tokenUsageId, usage?.token_usage);
}
} else {
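To illustrate the stricter single-pass token-usage lookup added above (the ids and numbers are hypothetical): the usage entry is now matched on both profile_manager and document_manager rather than taking the first item that has a run_id.

// Hypothetical usage items; only the entry matching both ids is picked.
const defaultLlmProfile = "profile-1";
const selectedDocId = "doc-9";
const data = [
  { profile_manager: "profile-2", document_manager: "doc-9", token_usage: { total_tokens: 10 } },
  { profile_manager: "profile-1", document_manager: "doc-9", token_usage: { total_tokens: 42 } },
];
const usage = data.find(
  (item) =>
    item?.profile_manager === defaultLlmProfile &&
    item?.document_manager === selectedDocId
);
console.log(usage?.token_usage); // { total_tokens: 42 }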
13 changes: 13 additions & 0 deletions frontend/src/helpers/GetStaticData.js
@@ -482,6 +482,18 @@ const pollForCompletion = (
return recursivePoll();
};

function getDocIdFromKey(key) {
// Split the key by '__'
const parts = key.split("__");

// Return the docId part, which is the second element in the array
if (parts.length === 3) {
return parts[1];
} else {
return null;
}
}

export {
CONNECTOR_TYPE_MAP,
O_AUTH_PROVIDERS,
Expand Down Expand Up @@ -523,4 +535,5 @@ export {
getLLMModelNamesForProfiles,
getFormattedTotalCost,
pollForCompletion,
getDocIdFromKey,
};
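A quick usage sketch for the new getDocIdFromKey helper, with a hypothetical key: keys follow the prompt_id__document_manager__profile_manager format built in OutputForDocModal, and the helper returns the middle segment, or null when the key does not have exactly three parts.

// Hypothetical key in the `${prompt_id}__${document_manager}__${profile_manager}` format.
const key = "prompt-123__doc-456__profile-789";
const parts = key.split("__");
const docId = parts.length === 3 ? parts[1] : null; // same logic as getDocIdFromKey
console.log(docId); // "doc-456"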
