diff --git a/autogpt_platform/backend/backend/server/v2/library/db.py b/autogpt_platform/backend/backend/server/v2/library/db.py index 50ff9544863f..911411352f8d 100644 --- a/autogpt_platform/backend/backend/server/v2/library/db.py +++ b/autogpt_platform/backend/backend/server/v2/library/db.py @@ -1,6 +1,5 @@ import json import logging -from typing import List import prisma.errors import prisma.models @@ -16,95 +15,150 @@ async def get_library_agents( - user_id: str, - search_query: str | None = None, -) -> List[backend.server.v2.library.model.LibraryAgent]: - """ - Returns all agents (AgentGraph) that belong to the user and all agents in their library (LibraryAgent table) - """ + user_id: str, search_query: str | None = None +) -> list[backend.server.v2.library.model.LibraryAgent]: logger.debug( - f"Getting library agents for user {user_id} with search query: {search_query}" + f"Fetching library agents for user_id={user_id} search_query={search_query}" ) - try: - # Sanitize and validate search query by escaping special characters - # Build where clause with sanitized inputs - where_clause = prisma.types.LibraryAgentWhereInput( - userId=user_id, - isDeleted=False, - isArchived=False, - **( - { - "OR": [ - { - "Agent": { - "name": { - "contains": search_query, - "mode": "insensitive", - } - } - }, - { - "Agent": { - "description": { - "contains": search_query, - "mode": "insensitive", - } - } - }, - ] - } - if search_query - else {} - ), + if search_query and len(search_query.strip()) > 100: + logger.warning(f"Search query too long: {search_query}") + raise backend.server.v2.store.exceptions.DatabaseError( + "Search query is too long." ) - # Get agents in user's library with nodes and links + where_clause = prisma.types.LibraryAgentWhereInput( + userId=user_id, + isDeleted=False, + isArchived=False, + ) + + if search_query: + where_clause["OR"] = [ + { + "Agent": { + "is": {"name": {"contains": search_query, "mode": "insensitive"}} + } + }, + { + "Agent": { + "is": { + "description": {"contains": search_query, "mode": "insensitive"} + } + } + }, + ] + + try: library_agents = await prisma.models.LibraryAgent.prisma().find_many( where=where_clause, include={ "Agent": { "include": { - "AgentNodes": { - "include": { - "Input": True, - "Output": True, - } - } + "AgentNodes": {"include": {"Input": True, "Output": True}} } - }, - "AgentPreset": {"include": {"InputPresets": True}}, + } }, order=[{"updatedAt": "desc"}], ) - - # Convert to Graph models first - graphs = [] - # Add library agents - for agent in library_agents: - if agent.Agent: - try: - graphs.append(backend.data.graph.GraphModel.from_db(agent.Agent)) - except Exception as e: - logger.error(f"Error processing library agent {agent.agentId}: {e}") - continue - - result = [ + logger.debug(f"Retrieved {len(library_agents)} agents for user_id={user_id}.") + return [ backend.server.v2.library.model.LibraryAgent.from_db(agent) for agent in library_agents ] + except prisma.errors.PrismaError as e: + logger.error(f"Database error fetching library agents: {e}") + raise backend.server.v2.store.exceptions.DatabaseError( + "Unable to fetch library agents." + ) + + +async def create_library_agent( + agent_id: str, agent_version: int, user_id: str +) -> prisma.models.LibraryAgent: + """ + Adds an agent to the user's library (LibraryAgent table) + """ + + try: - logger.debug(f"Found {len(result)} library agents") - return result + library_agent = await prisma.models.LibraryAgent.prisma().create( + data=prisma.types.LibraryAgentCreateInput( + userId=user_id, + agentId=agent_id, + agentVersion=agent_version, + isCreatedByUser=False, + useGraphIsActiveVersion=True, + ) + ) + return library_agent + except prisma.errors.PrismaError as e: + logger.error(f"Database error creating agent to library: {str(e)}") + raise backend.server.v2.store.exceptions.DatabaseError( + "Failed to create agent to library" + ) from e + + +async def update_agent_version_in_library( + user_id: str, agent_id: str, agent_version: int +) -> None: + """ + Updates the agent version in the library + """ + try: + await prisma.models.LibraryAgent.prisma().update( + where={ + "userId": user_id, + "agentId": agent_id, + "useGraphIsActiveVersion": True, + }, + data=prisma.types.LibraryAgentUpdateInput( + Agent=prisma.types.AgentGraphUpdateOneWithoutRelationsInput( + connect=prisma.types.AgentGraphWhereUniqueInput( + id=agent_id, + version=agent_version, + ), + ), + ), + ) + except prisma.errors.PrismaError as e: + logger.error(f"Database error updating agent version in library: {str(e)}") + raise backend.server.v2.store.exceptions.DatabaseError( + "Failed to update agent version in library" + ) from e + +async def update_library_agent( + library_agent_id: str, + user_id: str, + auto_update_version: bool = False, + is_favorite: bool = False, + is_archived: bool = False, + is_deleted: bool = False, +) -> None: + """ + Updates the library agent with the given fields + """ + try: + await prisma.models.LibraryAgent.prisma().update( + where={"id": library_agent_id, "userId": user_id}, + data=prisma.types.LibraryAgentUpdateInput( + useGraphIsActiveVersion=auto_update_version, + isFavorite=is_favorite, + isArchived=is_archived, + isDeleted=is_deleted, + ), + ) except prisma.errors.PrismaError as e: - logger.error(f"Database error getting library agents: {str(e)}") + logger.error(f"Database error updating library agent: {str(e)}") raise backend.server.v2.store.exceptions.DatabaseError( - "Failed to fetch library agents" + "Failed to update library agent" ) from e -async def add_agent_to_library(store_listing_version_id: str, user_id: str) -> None: +async def add_store_agent_to_library( + store_listing_version_id: str, user_id: str +) -> None: """ Finds the agent from the store listing version and adds it to the user's library (LibraryAgent table) if they don't already have it