-
Notifications
You must be signed in to change notification settings - Fork 859
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Support set Memory from UI #852
base: main
Are you sure you want to change the base?
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,6 @@ | ||
import datetime | ||
import json | ||
import logging | ||
import re | ||
from typing import Any, List | ||
|
||
|
@@ -23,8 +24,11 @@ | |
from app.models.tools import DatasourceInput | ||
from app.tools import TOOL_TYPE_MAPPING, create_pydantic_model_from_object, create_tool | ||
from app.tools.datasource import DatasourceTool, StructuredDatasourceTool | ||
from app.utils.helpers import get_first_non_null | ||
from app.utils.llm import LLM_MAPPING | ||
from prisma.models import LLM, Agent, AgentDatasource, AgentTool | ||
from prisma.models import LLM, Agent, AgentDatasource, AgentTool, MemoryDb | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
DEFAULT_PROMPT = ( | ||
"You are a helpful AI Assistant, answer the users questions to " | ||
|
@@ -193,33 +197,48 @@ async def _get_prompt(self, agent: Agent) -> str: | |
content = f"{content}" f"\n\n{datetime.datetime.now().strftime('%Y-%m-%d')}" | ||
return SystemMessage(content=content) | ||
|
||
async def _get_memory(self) -> List: | ||
memory_type = config("MEMORY", "motorhead") | ||
if memory_type == "redis": | ||
async def _get_memory(self, memory_db: MemoryDb) -> List: | ||
logger.debug(f"Use memory config: {memory_db}") | ||
if memory_db is None: | ||
memory_provider = config("MEMORY") | ||
options = {} | ||
else: | ||
memory_provider = memory_db.provider | ||
options = memory_db.options | ||
if memory_provider == "REDIS" or memory_provider == "redis": | ||
memory = ConversationBufferWindowMemory( | ||
chat_memory=RedisChatMessageHistory( | ||
session_id=( | ||
f"{self.agent_id}-{self.session_id}" | ||
if self.session_id | ||
else f"{self.agent_id}" | ||
), | ||
url=config("REDIS_MEMORY_URL", "redis://localhost:6379/0"), | ||
url=get_first_non_null( | ||
options.get("REDIS_MEMORY_URL"), | ||
config("REDIS_MEMORY_URL", "redis://localhost:6379/0"), | ||
), | ||
key_prefix="superagent:", | ||
), | ||
memory_key="chat_history", | ||
return_messages=True, | ||
output_key="output", | ||
k=config("REDIS_MEMORY_WINDOW", 10), | ||
k=get_first_non_null( | ||
options.get("REDIS_MEMORY_WINDOW"), | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If one wants to define different REDIS_MEMORY_WINDOW per each agent. We can't achieve it by defining global memory and use it this way. Is it something we should care about? @homanp There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @elisalimli @homanp Perhaps |
||
config("REDIS_MEMORY_WINDOW", 10), | ||
), | ||
) | ||
else: | ||
elif memory_provider == "MOTORHEAD" or memory_provider == "motorhead": | ||
memory = MotorheadMemory( | ||
session_id=( | ||
f"{self.agent_id}-{self.session_id}" | ||
if self.session_id | ||
else f"{self.agent_id}" | ||
), | ||
memory_key="chat_history", | ||
url=config("MEMORY_API_URL"), | ||
url=get_first_non_null( | ||
options.get("MEMORY_API_URL"), | ||
config("MEMORY_API_URL"), | ||
), | ||
return_messages=True, | ||
output_key="output", | ||
) | ||
|
@@ -235,7 +254,7 @@ async def get_agent(self): | |
agent_tools=self.agent_config.tools, | ||
) | ||
prompt = await self._get_prompt(agent=self.agent_config) | ||
memory = await self._get_memory() | ||
memory = await self._get_memory(memory_db=self.memory_config) | ||
|
||
if len(tools) > 0: | ||
agent = initialize_agent( | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,104 @@ | ||
import json | ||
|
||
import segment.analytics as analytics | ||
from decouple import config | ||
from fastapi import APIRouter, Depends | ||
|
||
from app.models.request import MemoryDb as MemoryDbRequest | ||
from app.models.response import MemoryDb as MemoryDbResponse | ||
from app.models.response import MemoryDbList as MemoryDbListResponse | ||
from app.utils.api import get_current_api_user, handle_exception | ||
from app.utils.prisma import prisma | ||
from prisma import Json | ||
|
||
SEGMENT_WRITE_KEY = config("SEGMENT_WRITE_KEY", None) | ||
|
||
router = APIRouter() | ||
analytics.write_key = SEGMENT_WRITE_KEY | ||
|
||
|
||
@router.post( | ||
"/memory-db", | ||
name="create", | ||
description="Create a new Memory Database", | ||
response_model=MemoryDbResponse, | ||
) | ||
async def create(body: MemoryDbRequest, api_user=Depends(get_current_api_user)): | ||
"""Endpoint for creating a Memory Database""" | ||
if SEGMENT_WRITE_KEY: | ||
analytics.track(api_user.id, "Created Memory Database") | ||
|
||
data = await prisma.memorydb.create( | ||
{ | ||
**body.dict(), | ||
"apiUserId": api_user.id, | ||
"options": json.dumps(body.options), | ||
} | ||
) | ||
data.options = json.dumps(data.options) | ||
return {"success": True, "data": data} | ||
|
||
|
||
@router.get( | ||
"/memory-dbs", | ||
name="list", | ||
description="List all Memory Databases", | ||
response_model=MemoryDbListResponse, | ||
) | ||
async def list(api_user=Depends(get_current_api_user)): | ||
"""Endpoint for listing all Memory Databases""" | ||
try: | ||
data = await prisma.memorydb.find_many( | ||
where={"apiUserId": api_user.id}, order={"createdAt": "desc"} | ||
) | ||
# Convert options to string | ||
for item in data: | ||
item.options = json.dumps(item.options) | ||
return {"success": True, "data": data} | ||
except Exception as e: | ||
handle_exception(e) | ||
|
||
|
||
@router.get( | ||
"/memory-dbs/{memory_db_id}", | ||
name="get", | ||
description="Get a single Memory Database", | ||
response_model=MemoryDbResponse, | ||
) | ||
async def get(memory_db_id: str, api_user=Depends(get_current_api_user)): | ||
"""Endpoint for getting a single Memory Database""" | ||
try: | ||
data = await prisma.memorydb.find_first( | ||
where={"id": memory_db_id, "apiUserId": api_user.id} | ||
) | ||
data.options = json.dumps(data.options) | ||
return {"success": True, "data": data} | ||
except Exception as e: | ||
handle_exception(e) | ||
|
||
|
||
@router.patch( | ||
"/memory-dbs/{memory_db_id}", | ||
name="update", | ||
description="Patch a Memory Database", | ||
response_model=MemoryDbResponse, | ||
) | ||
async def update( | ||
memory_db_id: str, body: MemoryDbRequest, api_user=Depends(get_current_api_user) | ||
): | ||
"""Endpoint for patching a Memory Database""" | ||
try: | ||
if SEGMENT_WRITE_KEY: | ||
analytics.track(api_user.id, "Updated Memory Database") | ||
data = await prisma.memorydb.update( | ||
where={"id": memory_db_id}, | ||
data={ | ||
**body.dict(exclude_unset=True), | ||
"apiUserId": api_user.id, | ||
"options": Json(body.options), | ||
}, | ||
) | ||
data.options = json.dumps(data.options) | ||
return {"success": True, "data": data} | ||
except Exception as e: | ||
handle_exception(e) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
-- CreateEnum | ||
CREATE TYPE "MemoryDbProvider" AS ENUM ('MOTORHEAD', 'REDIS'); | ||
|
||
-- AlterTable | ||
ALTER TABLE "Agent" ADD COLUMN "memory" "MemoryDbProvider" DEFAULT 'MOTORHEAD'; | ||
|
||
-- CreateTable | ||
CREATE TABLE "MemoryDb" ( | ||
"id" TEXT NOT NULL, | ||
"provider" "MemoryDbProvider" NOT NULL DEFAULT 'MOTORHEAD', | ||
"options" JSONB NOT NULL, | ||
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP, | ||
"updatedAt" TIMESTAMP(3) NOT NULL, | ||
"apiUserId" TEXT NOT NULL, | ||
|
||
CONSTRAINT "MemoryDb_pkey" PRIMARY KEY ("id") | ||
); | ||
|
||
-- AddForeignKey | ||
ALTER TABLE "MemoryDb" ADD CONSTRAINT "MemoryDb_apiUserId_fkey" FOREIGN KEY ("apiUserId") REFERENCES "ApiUser"("id") ON DELETE RESTRICT ON UPDATE CASCADE; |
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -93,6 +93,11 @@ enum VectorDbProvider { | |||||
SUPABASE | ||||||
} | ||||||
|
||||||
enum MemoryDbProvider { | ||||||
MOTORHEAD | ||||||
REDIS | ||||||
} | ||||||
|
||||||
model ApiUser { | ||||||
id String @id @default(uuid()) | ||||||
token String? | ||||||
|
@@ -106,6 +111,7 @@ model ApiUser { | |||||
workflows Workflow[] | ||||||
vectorDb VectorDb[] | ||||||
workflowConfigs WorkflowConfig[] | ||||||
MemoryDb MemoryDb[] | ||||||
} | ||||||
|
||||||
model Agent { | ||||||
|
@@ -120,6 +126,7 @@ model Agent { | |||||
updatedAt DateTime @updatedAt | ||||||
llms AgentLLM[] | ||||||
llmModel LLMModel? @default(GPT_3_5_TURBO_16K_0613) | ||||||
memory MemoryDbProvider? @default(MOTORHEAD) | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. let's use
Suggested change
|
||||||
prompt String? | ||||||
apiUserId String | ||||||
apiUser ApiUser @relation(fields: [apiUserId], references: [id]) | ||||||
|
@@ -253,3 +260,13 @@ model VectorDb { | |||||
apiUserId String | ||||||
apiUser ApiUser @relation(fields: [apiUserId], references: [id]) | ||||||
} | ||||||
|
||||||
model MemoryDb { | ||||||
id String @id @default(uuid()) | ||||||
provider MemoryDbProvider @default(MOTORHEAD) | ||||||
options Json | ||||||
createdAt DateTime @default(now()) | ||||||
updatedAt DateTime @updatedAt | ||||||
apiUserId String | ||||||
apiUser ApiUser @relation(fields: [apiUserId], references: [id]) | ||||||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,218 @@ | ||
"use client" | ||
|
||
import * as React from "react" | ||
import Image from "next/image" | ||
import { useRouter } from "next/navigation" | ||
import { zodResolver } from "@hookform/resolvers/zod" | ||
import { useForm } from "react-hook-form" | ||
import * as z from "zod" | ||
|
||
import { siteConfig } from "@/config/site" | ||
import { Api } from "@/lib/api" | ||
import { Button } from "@/components/ui/button" | ||
import { | ||
Dialog, | ||
DialogClose, | ||
DialogContent, | ||
DialogDescription, | ||
DialogFooter, | ||
DialogHeader, | ||
DialogTitle, | ||
} from "@/components/ui/dialog" | ||
import { | ||
Form, | ||
FormControl, | ||
FormDescription, | ||
FormField, | ||
FormItem, | ||
FormLabel, | ||
FormMessage, | ||
} from "@/components/ui/form" | ||
import { Input } from "@/components/ui/input" | ||
import { Skeleton } from "@/components/ui/skeleton" | ||
import { Spinner } from "@/components/ui/spinner" | ||
|
||
const motorheadSchema = z.object({ | ||
MEMORY_API_URL: z.string(), | ||
}) | ||
|
||
const redisSchema = z.object({ | ||
REDIS_MEMORY_URL: z.string(), | ||
REDIS_MEMORY_WINDOW: z.string(), | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. let's use |
||
}) | ||
|
||
const formSchema = z.object({ | ||
options: z.union([motorheadSchema, redisSchema]), | ||
}) | ||
|
||
export default function Memory({ | ||
profile, | ||
configuredMemories, | ||
}: { | ||
profile: any | ||
configuredMemories: any | ||
}) { | ||
const [open, setOpen] = React.useState<boolean>() | ||
const [selectedDB, setSelectedDB] = React.useState<any>() | ||
const router = useRouter() | ||
const api = new Api(profile.api_key) | ||
const { ...form } = useForm<z.infer<typeof formSchema>>({ | ||
resolver: zodResolver(formSchema), | ||
defaultValues: { | ||
options: {}, | ||
}, | ||
}) | ||
|
||
async function onSubmit(values: z.infer<typeof formSchema>) { | ||
const payload = { | ||
...values, | ||
options: | ||
Object.keys(values.options).length === 0 ? undefined : values.options, | ||
} | ||
|
||
const isExistingConnection = configuredMemories.find( | ||
(db: any) => db.provider === selectedDB.provider | ||
) | ||
|
||
if (isExistingConnection) { | ||
await api.patchMemoryDb(isExistingConnection.id, { | ||
...payload, | ||
provider: selectedDB.provider, | ||
}) | ||
} else { | ||
await api.createMemoryDb({ ...payload, provider: selectedDB.provider }) | ||
} | ||
|
||
form.reset() | ||
router.refresh() | ||
setOpen(false) | ||
} | ||
|
||
return ( | ||
<div className="container flex max-w-4xl flex-col space-y-10 pt-10"> | ||
<div className="flex flex-col"> | ||
<p className="text-lg font-medium">Storage</p> | ||
<p className="text-muted-foreground"> | ||
Connect your vector database to store your embeddings in your own | ||
databases. | ||
</p> | ||
</div> | ||
<div className="flex-col border-b"> | ||
{siteConfig.memoryDbs.map((memoryDb) => { | ||
const isConfigured = configuredMemories.find( | ||
(db: any) => db.provider === memoryDb.provider | ||
) | ||
|
||
return ( | ||
<div | ||
className="flex items-center justify-between border-t py-4" | ||
key={memoryDb.provider} | ||
> | ||
<div className="flex items-center space-x-4"> | ||
{isConfigured ? ( | ||
<div className="h-2 w-2 rounded-full bg-green-400" /> | ||
) : ( | ||
<div className="h-2 w-2 rounded-full bg-muted" /> | ||
)} | ||
<div className="flex items-center space-x-3"> | ||
<Image | ||
src={memoryDb.logo} | ||
width="40" | ||
height="40" | ||
alt={memoryDb.name} | ||
/> | ||
<p className="font-medium">{memoryDb.name}</p> | ||
</div> | ||
</div> | ||
<Button | ||
variant="outline" | ||
size="sm" | ||
onClick={() => { | ||
setSelectedDB(memoryDb) | ||
setOpen(true) | ||
}} | ||
> | ||
Settings | ||
</Button> | ||
</div> | ||
) | ||
})} | ||
</div> | ||
<Dialog | ||
onOpenChange={(isOpen) => { | ||
if (!isOpen) { | ||
form.reset() | ||
} | ||
|
||
setOpen(isOpen) | ||
}} | ||
open={open} | ||
> | ||
<DialogContent> | ||
<DialogHeader> | ||
<DialogTitle>{selectedDB?.name}</DialogTitle> | ||
<DialogDescription> | ||
Connect your private {selectedDB?.name} account to Superagent. | ||
</DialogDescription> | ||
</DialogHeader> | ||
<div className="flex flex-col"> | ||
<Form {...form}> | ||
<form | ||
onSubmit={form.handleSubmit(onSubmit)} | ||
className="w-full space-y-4" | ||
> | ||
{selectedDB?.metadata.map((metadataField: any) => ( | ||
<FormField | ||
key={metadataField.key} | ||
control={form.control} | ||
// @ts-ignore | ||
name={`options.${metadataField.key}`} | ||
render={({ field }) => ( | ||
<FormItem> | ||
<FormLabel>{metadataField.label}</FormLabel> | ||
{metadataField.type === "input" && ( | ||
<FormControl> | ||
{/* @ts-ignore */} | ||
<Input | ||
{...field} | ||
placeholder={ | ||
"placeholder" in metadataField | ||
? metadataField.placeholder | ||
: "" | ||
} | ||
type="text" | ||
/> | ||
</FormControl> | ||
)} | ||
{"helpText" in metadataField && ( | ||
<FormDescription className="pb-2"> | ||
{metadataField.helpText as string} | ||
</FormDescription> | ||
)} | ||
<FormMessage /> | ||
</FormItem> | ||
)} | ||
/> | ||
))} | ||
<DialogFooter> | ||
<DialogClose asChild> | ||
<Button type="button" variant="ghost"> | ||
Close | ||
</Button> | ||
</DialogClose> | ||
<Button type="submit" size="sm"> | ||
{form.control._formState.isSubmitting ? ( | ||
<Spinner /> | ||
) : ( | ||
"Save configuration" | ||
)} | ||
</Button> | ||
</DialogFooter> | ||
</form> | ||
</Form> | ||
</div> | ||
</DialogContent> | ||
</Dialog> | ||
</div> | ||
) | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
[ZoneTransfer] | ||
ZoneId=3 | ||
ReferrerUrl=https://github.com/getmetal/motorhead | ||
HostUrl=https://avatars.githubusercontent.com/u/75705874?s=48&v=4 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@bdqfork Why did you move this line?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Separating code and environment dependencies allows for optimal utilization of the cache built by Docker. During the debugging phase, there are more changes in the code compared to changes in the environment dependencies.