diff --git a/.env b/.env index 2a6736ae0cd..8732bcbb06f 100644 --- a/.env +++ b/.env @@ -87,6 +87,7 @@ PUBLIC_APP_COLOR=blue # can be any of tailwind colors: https://tailwindcss.com/d PUBLIC_APP_DATA_SHARING=#set to 1 to enable options & text regarding data sharing PUBLIC_APP_DISCLAIMER=#set to 1 to show a disclaimer on login page +TOOLS = [] # PUBLIC_APP_NAME=HuggingChat # PUBLIC_APP_ASSETS=huggingchat # PUBLIC_APP_COLOR=yellow diff --git a/README.md b/README.md index ca47b163366..961baf0d087 100644 --- a/README.md +++ b/README.md @@ -162,11 +162,11 @@ MODELS=`[ You can change things like the parameters, or customize the preprompt to better suit your needs. You can also add more models by adding more objects to the array, with different preprompts for example. -#### Custom prompt templates: +#### Custom prompt templates By default the prompt is constructed using `userMessageToken`, `assistantMessageToken`, `userMessageEndToken`, `assistantMessageEndToken`, `preprompt` parameters and a series of default templates. -However, these templates can be modified by setting the `chatPromptTemplate` and `webSearchQueryPromptTemplate` parameters. Note that if WebSearch is not enabled, only `chatPromptTemplate` needs to be set. The template language is https://handlebarsjs.com. The templates have access to the model's prompt parameters (`preprompt`, etc.). However, if the templates are specified it is recommended to inline the prompt parameters, as using the references (`{{preprompt}}`) is deprecated. +However, these templates can be modified by setting the `chatPromptTemplate` and `webSearchQueryPromptTemplate` parameters. Note that if WebSearch is not enabled, only `chatPromptTemplate` needs to be set. The template language is <https://handlebarsjs.com>. The templates have access to the model's prompt parameters (`preprompt`, etc.). However, if the templates are specified it is recommended to inline the prompt parameters, as using the references (`{{preprompt}}`) is deprecated. 
For example: @@ -300,6 +300,37 @@ If the model being hosted will be available on multiple servers/instances add th ``` +### Tools + +chat-ui supports two tools currently: + +- `webSearch` +- `textToImage` + +You can enable them by adding the following JSON to your `.env.local`: + +``` +TOOLS = `[ + { + "name" : "textToImage", + "model" : "[model name from the hub here]" + }, + { + "name" : "webSearch", + "key" : { + "type" : "serper", + "apiKey" : "[your key here]" + } + } +]` +``` + +Or a subset of these if you only want to enable some of the tools. + +The web search key `type` can be either `serper` or `serpapi`. + +The `textToImage` model can be [any model from the hub](https://huggingface.co/tasks/text-to-image) that matches the right task as long as the inference endpoint for it is enabled. + ## Deploying to a HF Space Create a `DOTENV_LOCAL` secret to your HF space with the content of your .env.local, and they will be picked up automatically when you run. diff --git a/src/lib/components/OpenWebSearchResults.svelte b/src/lib/components/OpenWebSearchResults.svelte index aac5fa54141..8205bafaa27 100644 --- a/src/lib/components/OpenWebSearchResults.svelte +++ b/src/lib/components/OpenWebSearchResults.svelte @@ -1,5 +1,10 @@
{/if} Web search + >Tools
@@ -39,48 +52,38 @@
- {#if webSearchMessages.length === 0} + {#if messagesToDisplay.length === 0}
{:else}
    - {#each webSearchMessages as message} - {#if message.messageType === "update"} -
  1. -
    + {#each messagesToDisplay as message} +
  2. +
    + {#if message.type === "error"} + + {:else}
    -

    - {message.message} -

    -
    - {#if message.args} -

    - {message.args} -

    - {/if} -
  3. - {:else if message.messageType === "error"} -
  4. -
    - -

    - {message.message} -

    -
    - {#if message.args} -

    - {message.args} -

    {/if} -
  5. - {/if} +

    + {message.message} +

    +
+ {#if message.type === "webSearch" && message.args} +

+ {message.args} +

+ {/if} + {/each} {/if} diff --git a/src/lib/components/WebSearchToggle.svelte b/src/lib/components/WebSearchToggle.svelte index 66295e7637c..6475e747b61 100644 --- a/src/lib/components/WebSearchToggle.svelte +++ b/src/lib/components/WebSearchToggle.svelte @@ -3,25 +3,56 @@ import CarbonInformation from "~icons/carbon/information"; import Switch from "./Switch.svelte"; - const toggle = () => ($webSearchParameters.useSearch = !$webSearchParameters.useSearch); + export let tools: { + webSearch: boolean; + textToImage: boolean; + }; + + const toggleWebSearch = () => ($webSearchParameters.useSearch = !$webSearchParameters.useSearch); + const toggleSDXL = () => ($webSearchParameters.useSDXL = !$webSearchParameters.useSDXL);
- -
Search web
-
- + {#if tools.webSearch} +
+ +
Web Search
+
+ +
+

+ When enabled, the request will be completed with relevant context fetched from the web. +

+
+
+
+ {/if} + {#if tools.textToImage}
-

- When enabled, the model will try to complement its answer with information queried from the - web. -

+ +
SDXL Images
+
+ +
+

+ When enabled, the model will try to generate images to go along with the answers. +

+
+
-
+ {/if}
diff --git a/src/lib/components/chat/ChatMessage.svelte b/src/lib/components/chat/ChatMessage.svelte index adf5cf961be..bfcbccf3269 100644 --- a/src/lib/components/chat/ChatMessage.svelte +++ b/src/lib/components/chat/ChatMessage.svelte @@ -2,8 +2,7 @@ import { marked } from "marked"; import markedKatex from "marked-katex-extension"; import type { Message } from "$lib/types/Message"; - import { afterUpdate, createEventDispatcher } from "svelte"; - import { deepestChild } from "$lib/utils/deepestChild"; + import { createEventDispatcher } from "svelte"; import { page } from "$app/stores"; import CodeBlock from "../CodeBlock.svelte"; @@ -17,7 +16,7 @@ import type { Model } from "$lib/types/Model"; import OpenWebSearchResults from "../OpenWebSearchResults.svelte"; - import type { WebSearchUpdate } from "$lib/types/MessageUpdate"; + import type { MessageUpdate, WebSearchUpdate } from "$lib/types/MessageUpdate"; function sanitizeMd(md: string) { let ret = md @@ -49,7 +48,7 @@ export let readOnly = false; export let isTapped = false; - export let webSearchMessages: WebSearchUpdate[]; + export let updateMessages: MessageUpdate[]; const dispatch = createEventDispatcher<{ retry: { content: string; id: Message["id"] }; @@ -57,8 +56,6 @@ }>(); let contentEl: HTMLElement; - let loadingEl: IconLoading; - let pendingTimeout: ReturnType; let isCopied = false; const renderer = new marked.Renderer(); @@ -89,40 +86,19 @@ $: tokens = marked.lexer(sanitizeMd(message.content)); - afterUpdate(() => { - loadingEl?.$destroy(); - clearTimeout(pendingTimeout); - - // Add loading animation to the last message if update takes more than 600ms - if (loading) { - pendingTimeout = setTimeout(() => { - if (contentEl) { - loadingEl = new IconLoading({ - target: deepestChild(contentEl), - props: { classNames: "loading inline ml-2" }, - }); - } - }, 600); - } - }); - - let searchUpdates: WebSearchUpdate[] = []; - - $: searchUpdates = ((webSearchMessages.length > 0 - ? 
webSearchMessages - : message.updates?.filter(({ type }) => type === "webSearch")) ?? []) as WebSearchUpdate[]; - $: downloadLink = message.from === "user" ? `${$page.url.pathname}/message/${message.id}/prompt` : undefined; let webSearchIsDone = true; $: webSearchIsDone = - searchUpdates.length > 0 && searchUpdates[searchUpdates.length - 1].messageType === "sources"; + updateMessages.length > 0 && updateMessages[updateMessages.length - 1].type === "finalAnswer"; $: webSearchSources = - searchUpdates && - searchUpdates?.filter(({ messageType }) => messageType === "sources")?.[0]?.sources; + updateMessages && + (updateMessages?.filter(({ type }) => type === "webSearch") as WebSearchUpdate[]).filter( + ({ messageType }) => messageType === "sources" + )?.[0]?.sources; $: if (isCopied) { setTimeout(() => { @@ -145,14 +121,14 @@
- {#if searchUpdates && searchUpdates.length > 0} + {#if updateMessages && updateMessages.filter(({ type }) => type === "agent").length > 0} {/if} - {#if !message.content && (webSearchIsDone || (webSearchMessages && webSearchMessages.length === 0))} + {#if !message.content && (webSearchIsDone || (updateMessages && updateMessages.length === 0))} {/if} @@ -160,6 +136,33 @@ class="prose max-w-none dark:prose-invert max-sm:prose-sm prose-headings:font-semibold prose-h1:text-lg prose-h2:text-base prose-h3:text-base prose-pre:bg-gray-800 dark:prose-pre:bg-gray-900" bind:this={contentEl} > + {#if message.files && message.files.length > 0} +
+ {#each message.files as file} +
+ {#if file.mime?.startsWith("image")} + tool output + {:else if file.mime?.startsWith("audio")} + + {/if} + {#if file.model} + Content generated using {file.model} + {/if} +
+ {/each} +
+
+ {/if} {#each tokens as token} {#if token.type === "code"} diff --git a/src/lib/components/chat/ChatMessages.svelte b/src/lib/components/chat/ChatMessages.svelte index e46a41f74bf..3cf6f0b1b77 100644 --- a/src/lib/components/chat/ChatMessages.svelte +++ b/src/lib/components/chat/ChatMessages.svelte @@ -3,17 +3,15 @@ import { snapScrollToBottom } from "$lib/actions/snapScrollToBottom"; import ScrollToBottomBtn from "$lib/components/ScrollToBottomBtn.svelte"; import { tick } from "svelte"; - import { randomUUID } from "$lib/utils/randomUuid"; import type { Model } from "$lib/types/Model"; import type { LayoutData } from "../../../routes/$types"; import ChatIntroduction from "./ChatIntroduction.svelte"; import ChatMessage from "./ChatMessage.svelte"; - import type { WebSearchUpdate } from "$lib/types/MessageUpdate"; + import type { MessageUpdate } from "$lib/types/MessageUpdate"; import { browser } from "$app/environment"; export let messages: Message[]; export let loading: boolean; - export let pending: boolean; export let isAuthor: boolean; export let currentModel: Model; export let settings: LayoutData["settings"]; @@ -22,7 +20,7 @@ let chatContainer: HTMLElement; - export let webSearchMessages: WebSearchUpdate[] = []; + export let updateMessages: MessageUpdate[] = []; async function scrollToBottom() { await tick(); @@ -37,7 +35,7 @@
@@ -48,20 +46,15 @@ {isAuthor} {readOnly} model={currentModel} - webSearchMessages={i === messages.length - 1 ? webSearchMessages : []} + updateMessages={!message.updates && i === messages.length - 1 + ? updateMessages + : message.updates ?? []} on:retry on:vote /> {:else} {/each} - {#if pending} - - {/if}
model.id === currentModel.id); @@ -43,6 +42,7 @@ dispatch("message", message); message = ""; }; + const showTools = settings?.tools.webSearch || settings?.tools.textToImage;
@@ -51,14 +51,13 @@ {/if} { @@ -69,12 +68,12 @@ class="dark:via-gray-80 pointer-events-none absolute inset-x-0 bottom-0 z-0 mx-auto flex w-full max-w-3xl flex-col items-center justify-center bg-gradient-to-t from-white via-white/80 to-white/0 px-3.5 py-4 dark:border-gray-800 dark:from-gray-900 dark:to-gray-900/0 max-md:border-t max-md:bg-white max-md:dark:bg-gray-900 sm:px-5 md:py-8 xl:max-w-4xl [&>*]:pointer-events-auto" >
- {#if settings?.searchEnabled} - + {#if showTools} + {/if} {#if loading} dispatch("stop")} /> {/if} diff --git a/src/lib/server/database.ts b/src/lib/server/database.ts index 0925a8a6a3d..a8b3febd0e3 100644 --- a/src/lib/server/database.ts +++ b/src/lib/server/database.ts @@ -1,5 +1,5 @@ import { MONGODB_URL, MONGODB_DB_NAME, MONGODB_DIRECT_CONNECTION } from "$env/static/private"; -import { MongoClient } from "mongodb"; +import { GridFSBucket, MongoClient } from "mongodb"; import type { Conversation } from "$lib/types/Conversation"; import type { SharedConversation } from "$lib/types/SharedConversation"; import type { WebSearch } from "$lib/types/WebSearch"; @@ -29,6 +29,7 @@ const settings = db.collection("settings"); const users = db.collection("users"); const webSearches = db.collection("webSearches"); const messageEvents = db.collection("messageEvents"); +const bucket = new GridFSBucket(db, { bucketName: "toolOutputs" }); export { client, db }; export const collections = { @@ -39,6 +40,7 @@ export const collections = { users, webSearches, messageEvents, + bucket, }; client.on("open", () => { diff --git a/src/lib/server/generateFromDefaultEndpoint.ts b/src/lib/server/generateFromEndpoint.ts similarity index 94% rename from src/lib/server/generateFromDefaultEndpoint.ts rename to src/lib/server/generateFromEndpoint.ts index b65e8d98100..dde420b3bea 100644 --- a/src/lib/server/generateFromDefaultEndpoint.ts +++ b/src/lib/server/generateFromEndpoint.ts @@ -11,8 +11,9 @@ interface Parameters { max_new_tokens: number; stop: string[]; } -export async function generateFromDefaultEndpoint( +export async function generateFromEndpoint( prompt: string, + model?: typeof defaultModel, parameters?: Partial ): Promise { const newParameters = { @@ -21,7 +22,7 @@ export async function generateFromDefaultEndpoint( return_full_text: false, }; - const randomEndpoint = modelEndpoint(defaultModel); + const randomEndpoint = modelEndpoint(model ?? 
defaultModel); const abortController = new AbortController(); diff --git a/src/lib/server/summarize.ts b/src/lib/server/summarize.ts index 3398cebd633..76161c230cc 100644 --- a/src/lib/server/summarize.ts +++ b/src/lib/server/summarize.ts @@ -1,5 +1,5 @@ import { buildPrompt } from "$lib/buildPrompt"; -import { generateFromDefaultEndpoint } from "$lib/server/generateFromDefaultEndpoint"; +import { generateFromEndpoint } from "$lib/server/generateFromEndpoint"; import { defaultModel } from "$lib/server/models"; export async function summarize(prompt: string) { @@ -12,7 +12,7 @@ export async function summarize(prompt: string) { model: defaultModel, }); - const generated_text = await generateFromDefaultEndpoint(summaryPrompt).catch((e) => { + const generated_text = await generateFromEndpoint(summaryPrompt).catch((e) => { console.error(e); return null; }); diff --git a/src/lib/server/tools.ts b/src/lib/server/tools.ts new file mode 100644 index 00000000000..d7fc9ac4260 --- /dev/null +++ b/src/lib/server/tools.ts @@ -0,0 +1,60 @@ +import { SERPAPI_KEY, SERPER_API_KEY, TOOLS } from "$env/static/private"; +import { z } from "zod"; + +const webSearchTool = z.object({ + name: z.literal("webSearch"), + key: z.union([ + z.object({ + type: z.literal("serpapi"), + apiKey: z.string().min(1).default(SERPAPI_KEY), + }), + z.object({ + type: z.literal("serper"), + apiKey: z.string().min(1).default(SERPER_API_KEY), + }), + ]), +}); + +const textToImageTool = z.object({ + name: z.literal("textToImage"), + model: z.string().min(1).default("stabilityai/stable-diffusion-xl-base-1.0"), + parameters: z.optional( + z.object({ + negative_prompt: z.string().optional(), + height: z.number().optional(), + width: z.number().optional(), + num_inference_steps: z.number().optional(), + guidance_scale: z.number().optional(), + }) + ), +}); + +const toolsDefinition = z.array(z.discriminatedUnion("name", [webSearchTool, textToImageTool])); + +export const tools = 
toolsDefinition.parse(JSON.parse(TOOLS)); + +// check if SERPAPI_KEY or SERPER_API_KEY are defined, and if so append them to the tools + +if (!tools.some((tool) => tool.name === "webSearch")) { + if (SERPAPI_KEY) { + tools.push({ + name: "webSearch", + key: { + type: "serpapi", + apiKey: SERPAPI_KEY, + }, + }); + } else if (SERPER_API_KEY) { + tools.push({ + name: "webSearch", + key: { + type: "serper", + apiKey: SERPER_API_KEY, + }, + }); + } +} + +export type Tool = z.infer[number]; +export type WebSearchTool = z.infer; +export type TextToImageTool = z.infer; diff --git a/src/lib/server/tools/uploadFile.ts b/src/lib/server/tools/uploadFile.ts new file mode 100644 index 00000000000..fa028c5ae04 --- /dev/null +++ b/src/lib/server/tools/uploadFile.ts @@ -0,0 +1,23 @@ +import type { Conversation } from "$lib/types/Conversation"; +import { sha256 } from "$lib/utils/sha256"; +import type { Tool } from "@huggingface/agents/src/types"; +import { collections } from "../database"; + +export async function uploadFile(file: Blob, conv: Conversation, tool?: Tool): Promise { + const sha = await sha256(await file.text()); + const filename = `${conv._id}-${sha}`; + + const upload = collections.bucket.openUploadStream(filename, { + metadata: { conversation: conv._id.toString(), model: tool?.model, mime: tool?.mime }, + }); + + upload.write((await file.arrayBuffer()) as unknown as Buffer); + upload.end(); + + // only return the filename when upload throws a finish event or a 10s time out occurs + return new Promise((resolve, reject) => { + upload.once("finish", () => resolve(filename)); + upload.once("error", reject); + setTimeout(() => reject(new Error("Upload timed out")), 10000); + }); +} diff --git a/src/lib/server/websearch/generateQuery.ts b/src/lib/server/websearch/generateQuery.ts index d812bff4d24..2d26a4ee9dd 100644 --- a/src/lib/server/websearch/generateQuery.ts +++ b/src/lib/server/websearch/generateQuery.ts @@ -1,6 +1,6 @@ import type { Message } from 
"$lib/types/Message"; import { format } from "date-fns"; -import { generateFromDefaultEndpoint } from "../generateFromDefaultEndpoint"; +import { generateFromEndpoint } from "../generateFromEndpoint"; import { defaultModel } from "../models"; export async function generateQuery(messages: Message[]) { @@ -13,7 +13,7 @@ export async function generateQuery(messages: Message[]) { previousMessages: previousUserMessages.map(({ content }) => content).join(" "), currentDate, }); - const searchQuery = await generateFromDefaultEndpoint(promptSearchQuery).then((query) => { + const searchQuery = await generateFromEndpoint(promptSearchQuery).then((query) => { // example of generating google query: // case 1 // user: tell me what happened yesterday diff --git a/src/lib/server/websearch/runWebSearch.ts b/src/lib/server/websearch/runWebSearch.ts index e0c62264615..d57f949e876 100644 --- a/src/lib/server/websearch/runWebSearch.ts +++ b/src/lib/server/websearch/runWebSearch.ts @@ -63,7 +63,7 @@ export async function runWebSearch( text = await parseWeb(link); appendUpdate("Browsing webpage", [link]); } catch (e) { - console.error(`Error parsing webpage "${link}"`, e); + // console.error(`Error parsing webpage "${link}"`, e); } const MAX_N_CHUNKS = 100; const texts = chunk(text, CHUNK_CAR_LEN).slice(0, MAX_N_CHUNKS); diff --git a/src/lib/server/websearch/searchWeb.ts b/src/lib/server/websearch/searchWeb.ts index eab3c3d5f7e..60af1b865a8 100644 --- a/src/lib/server/websearch/searchWeb.ts +++ b/src/lib/server/websearch/searchWeb.ts @@ -1,14 +1,14 @@ -import { SERPAPI_KEY, SERPER_API_KEY } from "$env/static/private"; - import { getJson } from "serpapi"; import type { GoogleParameters } from "serpapi"; +import { tools, type WebSearchTool } from "../tools"; +const webSearchTool = tools.find((tool) => tool.name === "webSearch") as WebSearchTool; // Show result as JSON export async function searchWeb(query: string) { - if (SERPER_API_KEY) { + if (webSearchTool.key.type === "serper") { return 
await searchWebSerper(query); } - if (SERPAPI_KEY) { + if (webSearchTool.key.type === "serpapi") { return await searchWebSerpApi(query); } throw new Error("No Serper.dev or SerpAPI key found"); @@ -25,7 +25,7 @@ export async function searchWebSerper(query: string) { method: "POST", body: JSON.stringify(params), headers: { - "x-api-key": SERPER_API_KEY, + "x-api-key": webSearchTool.key.apiKey, "Content-type": "application/json; charset=UTF-8", }, }); @@ -51,7 +51,7 @@ export async function searchWebSerpApi(query: string) { hl: "en", gl: "us", google_domain: "google.com", - api_key: SERPAPI_KEY, + api_key: webSearchTool.key.apiKey, } satisfies GoogleParameters; // Show result as JSON diff --git a/src/lib/stores/webSearchParameters.ts b/src/lib/stores/webSearchParameters.ts index fd088a60621..868eeb92d31 100644 --- a/src/lib/stores/webSearchParameters.ts +++ b/src/lib/stores/webSearchParameters.ts @@ -1,9 +1,11 @@ import { writable } from "svelte/store"; export interface WebSearchParameters { useSearch: boolean; + useSDXL: boolean; nItems: number; } export const webSearchParameters = writable({ useSearch: false, + useSDXL: false, nItems: 5, }); diff --git a/src/lib/types/FileMetaData.ts b/src/lib/types/FileMetaData.ts new file mode 100644 index 00000000000..6d681ee9101 --- /dev/null +++ b/src/lib/types/FileMetaData.ts @@ -0,0 +1,7 @@ +export interface FileMetaData { + convId: string; + sha256: string; + createdAt: Date; + model: string; + tool: string; +} diff --git a/src/lib/types/Message.ts b/src/lib/types/Message.ts index 2d092c10f0b..9f1cce8d129 100644 --- a/src/lib/types/Message.ts +++ b/src/lib/types/Message.ts @@ -2,6 +2,12 @@ import type { MessageUpdate } from "./MessageUpdate"; import type { Timestamps } from "./Timestamps"; import type { WebSearch } from "./WebSearch"; +export interface File { + sha256: string; + model?: string; + mime?: string; +} + export type Message = Partial & { from: "user" | "assistant"; id: ReturnType; @@ -10,4 +16,5 @@ export type 
Message = Partial & { webSearchId?: WebSearch["_id"]; // legacy version webSearch?: WebSearch; score?: -1 | 0 | 1; + files?: File[]; // filenames }; diff --git a/src/lib/types/MessageUpdate.ts b/src/lib/types/MessageUpdate.ts index 613b92e05b8..707039cc911 100644 --- a/src/lib/types/MessageUpdate.ts +++ b/src/lib/types/MessageUpdate.ts @@ -1,4 +1,6 @@ +import type { File } from "./Message"; import type { WebSearchSource } from "./WebSearch"; +import type { Update } from "@huggingface/agents/src/types"; export type FinalAnswer = { type: "finalAnswer"; @@ -10,12 +12,9 @@ export type TextStreamUpdate = { token: string; }; -export type AgentUpdate = { +export interface AgentUpdate extends Update { type: "agent"; - agent: string; - content: string; - binary?: Blob; -}; +} export type WebSearchUpdate = { type: "webSearch"; @@ -31,9 +30,21 @@ export type StatusUpdate = { message?: string; }; +export type FileUpdate = { + type: "file"; + file: File; +}; + +export type ErrorUpdate = { + type: "error"; + message: string; +}; + export type MessageUpdate = | FinalAnswer | TextStreamUpdate | AgentUpdate | WebSearchUpdate - | StatusUpdate; + | StatusUpdate + | FileUpdate + | ErrorUpdate; diff --git a/src/routes/+layout.server.ts b/src/routes/+layout.server.ts index ba71c157875..51dacd7ee56 100644 --- a/src/routes/+layout.server.ts +++ b/src/routes/+layout.server.ts @@ -6,7 +6,8 @@ import { UrlDependency } from "$lib/types/UrlDependency"; import { defaultModel, models, oldModels, validateModel } from "$lib/server/models"; import { authCondition, requiresUser } from "$lib/server/auth"; import { DEFAULT_SETTINGS } from "$lib/types/Settings"; -import { SERPAPI_KEY, SERPER_API_KEY, MESSAGES_BEFORE_LOGIN } from "$env/static/private"; +import { MESSAGES_BEFORE_LOGIN } from "$env/static/private"; +import { tools } from "$lib/server/tools"; export const load: LayoutServerLoad = async ({ locals, depends, url }) => { const { conversations } = collections; @@ -61,7 +62,10 @@ export const 
load: LayoutServerLoad = async ({ locals, depends, url }) => { DEFAULT_SETTINGS.shareConversationsWithModelAuthors, ethicsModalAcceptedAt: settings?.ethicsModalAcceptedAt ?? null, activeModel: settings?.activeModel ?? DEFAULT_SETTINGS.activeModel, - searchEnabled: !!(SERPAPI_KEY || SERPER_API_KEY), + tools: { + webSearch: tools.some((tool) => tool.name === "webSearch"), + textToImage: tools.some((tool) => tool.name === "textToImage"), + }, customPrompts: settings?.customPrompts ?? {}, }, models: models.map((model) => ({ diff --git a/src/routes/conversation/[id]/+page.svelte b/src/routes/conversation/[id]/+page.svelte index 9de3d10aaf5..9dde7b604d5 100644 --- a/src/routes/conversation/[id]/+page.svelte +++ b/src/routes/conversation/[id]/+page.svelte @@ -14,7 +14,7 @@ import { webSearchParameters } from "$lib/stores/webSearchParameters"; import type { Message } from "$lib/types/Message"; import { PUBLIC_APP_DISCLAIMER } from "$env/static/public"; - import type { MessageUpdate, WebSearchUpdate } from "$lib/types/MessageUpdate"; + import type { MessageUpdate } from "$lib/types/MessageUpdate"; export let data; @@ -22,7 +22,7 @@ let lastLoadedMessages = data.messages; let isAborted = false; - let webSearchMessages: WebSearchUpdate[] = []; + let updateMessages: MessageUpdate[] = []; // Since we modify the messages array locally, we don't want to reset it if an old version is passed $: if (data.messages !== lastLoadedMessages) { @@ -31,7 +31,6 @@ } let loading = false; - let pending = false; let loginRequired = false; // this function is used to send new message to the backends @@ -41,7 +40,6 @@ try { isAborted = false; loading = true; - pending = true; // first we check if the messageId already exists, indicating a retry @@ -58,8 +56,20 @@ { from: "user", content: message, id: messageId }, ]; + messages = [...messages, { from: "assistant", id: randomUUID(), content: "", files: [] }]; + const responseId = randomUUID(); + const toolsToBeUsed = []; + + if 
($webSearchParameters.useSearch) { + toolsToBeUsed.push("webSearch"); + } + + if ($webSearchParameters.useSDXL) { + toolsToBeUsed.push("textToImage"); + } + const response = await fetch(`${base}/conversation/${$page.params.id}`, { method: "POST", headers: { "Content-Type": "application/json" }, @@ -68,7 +78,7 @@ id: messageId, response_id: responseId, is_retry: isRetry, - web_search: $webSearchParameters.useSearch, + tools: toolsToBeUsed, }), }); @@ -83,7 +93,7 @@ // this is a bit ugly // we read the stream until we get the final answer - while (finalAnswer === "") { + while (finalAnswer === "" && !isAborted) { // await new Promise((r) => setTimeout(r, 25)); // check for abort @@ -110,24 +120,19 @@ try { let update = JSON.parse(el) as MessageUpdate; if (update.type === "finalAnswer") { + updateMessages = [...updateMessages, update]; finalAnswer = update.text; - invalidate(UrlDependency.Conversation); } else if (update.type === "stream") { - pending = false; - let lastMessage = messages[messages.length - 1]; - - if (lastMessage.from !== "assistant") { - messages = [ - ...messages, - { from: "assistant", id: randomUUID(), content: update.token }, - ]; - } else { - lastMessage.content += update.token; - messages = [...messages]; - } + lastMessage.content += update.token; + messages = [...messages]; } else if (update.type === "webSearch") { - webSearchMessages = [...webSearchMessages, update]; + updateMessages = [...updateMessages, update]; + } else if (update.type === "agent") { + updateMessages = [...updateMessages, update]; + } else if (update.type === "file") { + messages[messages.length - 1].files?.push(update.file); + messages = [...messages]; } } catch (parseError) { // in case of parsing error we wait for the next message @@ -138,8 +143,7 @@ } // reset the websearchmessages - webSearchMessages = []; - + updateMessages = []; await invalidate(UrlDependency.ConversationList); } catch (err) { if (err instanceof Error && err.message.includes("overloaded")) { @@ 
-154,7 +158,9 @@ console.error(err); } finally { loading = false; - pending = false; + // wait 500ms + await new Promise((r) => setTimeout(r, 500)); + invalidate(UrlDependency.Conversation); } } @@ -216,9 +222,8 @@ writeMessage(event.detail)} on:retry={(event) => writeMessage(event.detail.content, event.detail.id)} on:vote={(event) => voteMessage(event.detail.score, event.detail.id)} diff --git a/src/routes/conversation/[id]/+server.ts b/src/routes/conversation/[id]/+server.ts index c3d1b8d0486..aa9f4c6804a 100644 --- a/src/routes/conversation/[id]/+server.ts +++ b/src/routes/conversation/[id]/+server.ts @@ -1,24 +1,24 @@ import { HF_ACCESS_TOKEN, MESSAGES_BEFORE_LOGIN, RATE_LIMIT } from "$env/static/private"; -import { buildPrompt } from "$lib/buildPrompt"; -import { PUBLIC_SEP_TOKEN } from "$lib/constants/publicSepToken"; import { authCondition, requiresUser } from "$lib/server/auth"; import { collections } from "$lib/server/database"; import { modelEndpoint } from "$lib/server/modelEndpoint"; import { models } from "$lib/server/models"; import { ERROR_MESSAGES } from "$lib/stores/errors"; import type { Message } from "$lib/types/Message"; -import { trimPrefix } from "$lib/utils/trimPrefix"; -import { trimSuffix } from "$lib/utils/trimSuffix"; import { textGenerationStream } from "@huggingface/inference"; import { error } from "@sveltejs/kit"; import { ObjectId } from "mongodb"; import { z } from "zod"; import { AwsClient } from "aws4fetch"; -import type { MessageUpdate } from "$lib/types/MessageUpdate"; +import type { AgentUpdate, MessageUpdate } from "$lib/types/MessageUpdate"; import { runWebSearch } from "$lib/server/websearch/runWebSearch"; -import type { WebSearch } from "$lib/types/WebSearch"; import { abortedGenerations } from "$lib/server/abortedGenerations"; import { summarize } from "$lib/server/summarize"; +import type { TextGenerationStreamOutput } from "@huggingface/inference"; +import { HfChatAgent } from "@huggingface/agents"; +import { uploadFile 
} from "$lib/server/tools/uploadFile.js"; +import type { Tool } from "@huggingface/agents/src/types.js"; +import { tools as toolSettings, type TextToImageTool } from "$lib/server/tools.js"; export async function POST({ request, fetch, locals, params, getClientAddress }) { const id = z.string().parse(params.id); @@ -84,20 +84,20 @@ export async function POST({ request, fetch, locals, params, getClientAddress }) response_id: responseId, id: messageId, is_retry, - web_search: webSearch, + tools, } = z .object({ inputs: z.string().trim().min(1), id: z.optional(z.string().uuid()), response_id: z.optional(z.string().uuid()), is_retry: z.optional(z.boolean()), - web_search: z.optional(z.boolean()), + tools: z.array(z.string()), }) .parse(json); // get the list of messages // while checking for retries - let messages = (() => { + const messages = (() => { if (is_retry && messageId) { // if the message is a retry, replace the message and remove the messages after it let retryMessageIdx = conv.messages.findIndex((message) => message.id === messageId); @@ -122,69 +122,97 @@ export async function POST({ request, fetch, locals, params, getClientAddress }) ]; })() satisfies Message[]; + // save user prompt + await collections.conversations.updateOne( + { + _id: convId, + }, + { + $set: { + messages, + title: (await summarize(newPrompt)) ?? 
conv.title, + updatedAt: new Date(), + }, + } + ); + + // fetch the endpoint + const randomEndpoint = modelEndpoint(model); + + let usedFetch = fetch; + + if (randomEndpoint.host === "sagemaker") { + const aws = new AwsClient({ + accessKeyId: randomEndpoint.accessKey, + secretAccessKey: randomEndpoint.secretKey, + sessionToken: randomEndpoint.sessionToken, + service: "sagemaker", + }); + + usedFetch = aws.fetch.bind(aws) as typeof fetch; + } + // we now build the stream const stream = new ReadableStream({ async start(controller) { const updates: MessageUpdate[] = []; - function update(newUpdate: MessageUpdate) { if (newUpdate.type !== "stream") { updates.push(newUpdate); } - controller.enqueue(JSON.stringify(newUpdate) + "\n"); + try { + controller.enqueue(JSON.stringify(newUpdate) + "\n"); + } catch (e) { + try { + stream.cancel(); + } catch (f) { + console.error(f); + // ignore + } + } } - update({ type: "status", status: "started" }); - - let webSearchResults: WebSearch | undefined; + function getStream(inputs: string) { + if (!conv) { + throw new Error("Conversation not found"); + } - if (webSearch) { - webSearchResults = await runWebSearch(conv, newPrompt, update); + return textGenerationStream( + { + inputs, + parameters: { + ...models.find((m) => m.id === conv.model)?.parameters, + return_full_text: false, + max_new_tokens: 4000, + }, + model: randomEndpoint.url, + accessToken: randomEndpoint.host === "sagemaker" ? undefined : HF_ACCESS_TOKEN, + }, + { + use_cache: false, + fetch: usedFetch, + } + ); } - // we can now build the prompt using the messages - const prompt = await buildPrompt({ - messages, - model, - webSearch: webSearchResults, - preprompt: settings?.customPrompts?.[model.id] ?? 
model.preprompt, - locals: locals, + messages.push({ + from: "assistant", + content: "", + updates: updates, + files: [], + id: (responseId as Message["id"]) || crypto.randomUUID(), + createdAt: new Date(), + updatedAt: new Date(), }); - // fetch the endpoint - const randomEndpoint = modelEndpoint(model); - - let usedFetch = fetch; - - if (randomEndpoint.host === "sagemaker") { - const aws = new AwsClient({ - accessKeyId: randomEndpoint.accessKey, - secretAccessKey: randomEndpoint.secretKey, - sessionToken: randomEndpoint.sessionToken, - service: "sagemaker", - }); - - usedFetch = aws.fetch.bind(aws) as typeof fetch; - } + const lastMessage = messages[messages.length - 1]; async function saveLast(generated_text: string) { if (!conv) { throw new Error("Conversation not found"); } - const lastMessage = messages[messages.length - 1]; - if (lastMessage) { - // We could also check if PUBLIC_ASSISTANT_MESSAGE_TOKEN is present and use it to slice the text - if (generated_text.startsWith(prompt)) { - generated_text = generated_text.slice(prompt.length); - } - - generated_text = trimSuffix( - trimPrefix(generated_text, "<|startoftext|>"), - PUBLIC_SEP_TOKEN - ).trimEnd(); - // remove the stop tokens for (const stop of [...(model?.parameters?.stop ?? []), "<|endoftext|>"]) { if (generated_text.endsWith(stop)) { @@ -206,73 +234,150 @@ export async function POST({ request, fetch, locals, params, getClientAddress }) } ); - update({ - type: "finalAnswer", - text: generated_text, - }); + update({ type: "finalAnswer", text: generated_text }); } } - const tokenStream = textGenerationStream( - { - parameters: { - ...models.find((m) => m.id === conv.model)?.parameters, - return_full_text: false, - }, - model: randomEndpoint.url, - inputs: prompt, - accessToken: randomEndpoint.host === "sagemaker" ? 
undefined : HF_ACCESS_TOKEN, - }, - { - use_cache: false, - fetch: usedFetch, - } - ); - - for await (const output of tokenStream) { - // if not generated_text is here it means the generation is not done + const streamCallback = async (output: TextGenerationStreamOutput) => { if (!output.generated_text) { // else we get the next token if (!output.token.special) { - const lastMessage = messages[messages.length - 1]; + // if the last message is not from assistant, it means this is the first token + const date = abortedGenerations.get(convId.toString()); + + if (date && date > promptedAt) { + saveLast(lastMessage.content); + } + + if (!output) { + return; + } + + // otherwise we just concatenate tokens + lastMessage.content += output.token.text; + update({ type: "stream", token: output.token.text, }); - - // if the last message is not from assistant, it means this is the first token - if (lastMessage?.from !== "assistant") { - // so we create a new message - messages = [ - ...messages, - // id doesn't match the backend id but it's not important for assistant messages - // First token has a space at the beginning, trim it - { - from: "assistant", - content: output.token.text.trimStart(), - webSearch: webSearchResults, - updates: updates, - id: (responseId as Message["id"]) || crypto.randomUUID(), - createdAt: new Date(), - updatedAt: new Date(), - }, - ]; - } else { - const date = abortedGenerations.get(convId.toString()); - if (date && date > promptedAt) { - saveLast(lastMessage.content); - } - if (!output) { - break; - } - - // otherwise we just concatenate tokens - lastMessage.content += output.token.text; - } } - } else { - saveLast(output.generated_text); } + }; + + const listTools: Tool[] = []; + + if (toolSettings.some((t) => t.name === "webSearch")) { + const webSearchTool: Tool = { + name: "webSearch", + description: + "This tool can be used to search the web for extra information. 
It will return the most relevant paragraphs from the web", + examples: [ + { + prompt: "What are the best restaurants in Paris?", + code: '{"tool" : "imageToText", "input" : "What are the best restaurants in Paris?"}', + tools: ["webSearch"], + }, + { + prompt: "Who is the president of the United States?", + code: '{"tool" : "imageToText", "input" : "Who is the president of the United States?"}', + tools: ["webSearch"], + }, + ], + call: async (input, _) => { + const data = await input; + if (typeof data !== "string") throw "Input must be a string."; + + const results = await runWebSearch(conv, data, update); + return results.context; + }, + }; + + listTools.push(webSearchTool); + } + + if (toolSettings.some((t) => t.name === "textToImage")) { + const toolParameters = toolSettings.find( + (t) => t.name === "textToImage" + ) as TextToImageTool; + // eslint-disable-next-line @typescript-eslint/no-unused-vars + const SDXLTool: Tool = { + name: "textToImage", + description: + "This tool can be used to generate an image from text. 
It will return the image.", + mime: "image/jpeg", + model: "https://huggingface.co/" + toolParameters.model, + examples: [ + { + prompt: "Generate an image of a cat wearing a top hat", + code: '{"tool" : "textToImage", "input" : "a cat wearing a top hat"}', + tools: ["textToImage"], + }, + { + prompt: "Draw a brown dog on a beach", + code: '{"tool" : "textToImage", "input" : "drawing of a brown dog on a beach"}', + tools: ["textToImage"], + }, + ], + call: async (input, inference) => { + const data = await input; + if (typeof data !== "string") throw "Input must be a string."; + + const imageBase = await inference.textToImage( + { + inputs: data, + model: toolParameters.model, + parameters: toolParameters.parameters, + }, + { wait_for_model: true } + ); + return imageBase; + }, + }; + + listTools.push(SDXLTool); + } + + const agent = new HfChatAgent({ + accessToken: HF_ACCESS_TOKEN, + llm: getStream, + chatFormat: (inputs: { messages: Message[] }) => + model.chatPromptRender({ + messages: inputs.messages, + preprompt: settings?.customPrompts?.[model.id] ?? 
model.preprompt, + }), + callbacks: { + onFile: async (file, tool) => { + const filename = await uploadFile(file, conv, tool); + + const fileObject = { + sha256: filename.split("-")[1], + model: tool?.model, + mime: tool?.mime, + }; + lastMessage.files?.push(fileObject); + update({ type: "file", file: fileObject }); + }, + onUpdate: async (agentUpdate) => { + update({ ...agentUpdate, type: "agent" } satisfies AgentUpdate); + }, + onStream: streamCallback, + onFinalAnswer: async (answer) => { + update({ type: "finalAnswer", text: answer }); + saveLast(answer); + }, + onError: async (errorUpdate) => { + update({ type: "error", message: errorUpdate.message }); + }, + }, + chatHistory: messages, + tools: listTools.filter((t) => tools.includes(t.name)), + }); + + try { + await agent.chat(newPrompt); + } catch (e) { + console.error(e); + return new Error((e as Error).message); } }, async cancel() { diff --git a/src/routes/conversation/[id]/output/[sha256]/+server.ts b/src/routes/conversation/[id]/output/[sha256]/+server.ts new file mode 100644 index 00000000000..1ef6f21ede5 --- /dev/null +++ b/src/routes/conversation/[id]/output/[sha256]/+server.ts @@ -0,0 +1,60 @@ +import { authCondition } from "$lib/server/auth"; +import { collections } from "$lib/server/database"; +import { error } from "@sveltejs/kit"; +import { ObjectId } from "mongodb"; +import { z } from "zod"; +import type { RequestHandler } from "./$types"; + +export const GET: RequestHandler = async ({ locals, params }) => { + const convId = new ObjectId(z.string().parse(params.id)); + const sha256 = z.string().parse(params.sha256); + + const userId = locals.user?._id ?? 
locals.sessionId; + + // check user + if (!userId) { + throw error(401, "Unauthorized"); + } + + // check if the user has access to the conversation + const conv = await collections.conversations.findOne({ + _id: convId, + ...authCondition(locals), + }); + + if (!conv) { + throw error(404, "Conversation not found"); + } + + const fileId = collections.bucket.find({ filename: `${convId}-${sha256}` }); + let mime; + + const content = await fileId.next().then(async (file) => { + if (!file) { + throw error(404, "File not found"); + } + + if (file.metadata?.conversation !== convId.toString()) { + throw error(403, "You don't have access to this file."); + } + + mime = file.metadata?.mime; + + const fileStream = collections.bucket.openDownloadStream(file._id); + + const fileBuffer = await new Promise((resolve, reject) => { + const chunks: Uint8Array[] = []; + fileStream.on("data", (chunk) => chunks.push(chunk)); + fileStream.on("error", reject); + fileStream.on("end", () => resolve(Buffer.concat(chunks))); + }); + + return fileBuffer; + }); + + return new Response(content, { + headers: { + "Content-Type": mime ?? "application/octet-stream", + }, + }); +}; diff --git a/src/routes/conversation/[id]/upload/+server.ts b/src/routes/conversation/[id]/upload/+server.ts new file mode 100644 index 00000000000..e57db0a952d --- /dev/null +++ b/src/routes/conversation/[id]/upload/+server.ts @@ -0,0 +1,39 @@ +import { authCondition } from "$lib/server/auth"; +import { collections } from "$lib/server/database"; +import { uploadFile } from "$lib/server/tools/uploadFile"; +import { error } from "@sveltejs/kit"; +import { ObjectId } from "mongodb"; +import { z } from "zod"; +import type { RequestHandler } from "../$types"; + +export const POST: RequestHandler = async ({ locals, params, request }) => { + const convId = new ObjectId(z.string().parse(params.id)); + const data = await request.formData(); + + const userId = locals.user?._id ?? 
locals.sessionId; + + // check user + if (!userId) { + throw error(401, "Unauthorized"); + } + + // check if the user has access to the conversation + const conv = await collections.conversations.findOne({ + _id: convId, + ...authCondition(locals), + }); + + if (!conv) { + throw error(404, "Conversation not found"); + } + + const file = data.get("file") as File; + + if (!file) { + throw error(400, "No file provided"); + } + + const filename = await uploadFile(file, conv); + + return new Response(filename); +};