From ebb74edb7dab31e802f272319b94932c681531d6 Mon Sep 17 00:00:00 2001
From: Sourcery AI <>
Date: Mon, 9 Oct 2023 21:26:51 +0000
Subject: [PATCH] 'Refactored by Sourcery'

---
 libs/legacy/app/api/agent_documents.py       | 14 ++-
 libs/legacy/app/api/agent_tools.py           | 14 ++-
 libs/legacy/app/api/agents.py                |  6 +-
 libs/legacy/app/api/auth.py                  | 40 ++++-----
 libs/legacy/app/api/users.py                 |  5 +-
 libs/legacy/app/lib/agents/agent.py          | 12 +--
 libs/legacy/app/lib/agents/base.py           | 28 +++---
 libs/legacy/app/lib/agents/factory.py        |  7 +-
 libs/legacy/app/lib/api_tokens.py            |  3 +-
 libs/legacy/app/lib/auth/prisma.py           | 69 ++++++---------
 libs/legacy/app/lib/callbacks.py             |  7 +-
 libs/legacy/app/lib/documents.py             | 89 ++++++++++----------
 libs/legacy/app/lib/loaders/sitemap.py       |  4 +-
 libs/legacy/app/lib/parsers.py               | 17 ++--
 libs/legacy/app/lib/splitters.py             | 15 ++--
 libs/legacy/app/lib/tools.py                 | 24 ++----
 libs/legacy/app/lib/vectorstores/pinecone.py | 26 +++---
 libs/superagent/app/agents/base.py           |  3 +-
 libs/superagent/app/tools/bing_search.py     |  6 +-
 libs/superagent/app/tools/datasource.py      | 15 ++--
 libs/superagent/app/tools/e2b.py             |  7 +-
 libs/superagent/app/tools/metaphor.py        |  6 +-
 libs/superagent/app/tools/openapi.py         |  6 +-
 libs/superagent/app/tools/pubmed.py          |  6 +-
 libs/superagent/app/tools/replicate.py       |  6 +-
 libs/superagent/app/tools/wolfram_alpha.py   |  3 +-
 libs/superagent/app/tools/zapier.py          |  6 +-
 libs/superagent/app/utils/api.py             |  3 +-
 libs/superagent/app/vectorstores/pinecone.py | 26 +++---
 29 files changed, 186 insertions(+), 287 deletions(-)

diff --git a/libs/legacy/app/api/agent_documents.py b/libs/legacy/app/api/agent_documents.py
index 1000ffdce..fdf5fd2dd 100644
--- a/libs/legacy/app/api/agent_documents.py
+++ b/libs/legacy/app/api/agent_documents.py
@@ -17,15 +17,11 @@
 def parse_filter_params(request: Request):
     query_params = request.query_params
-    filter_params = {}
-
-    for k, v in query_params.items():
-        if k.startswith("filter[") and k.endswith("]"):
-            # Removing 'filter[' from start and ']' from end
-            filter_key = k[7:-1]
-            filter_params[filter_key] = v
-
-    return filter_params
+    return {
+        k[7:-1]: v
+        for k, v in query_params.items()
+        if k.startswith("filter[") and k.endswith("]")
+    }
 
 
 @router.post(
diff --git a/libs/legacy/app/api/agent_tools.py b/libs/legacy/app/api/agent_tools.py
index e6dde8aae..a7eb8c1c3 100644
--- a/libs/legacy/app/api/agent_tools.py
+++ b/libs/legacy/app/api/agent_tools.py
@@ -14,15 +14,11 @@
 def parse_filter_params(request: Request):
     query_params = request.query_params
-    filter_params = {}
-
-    for k, v in query_params.items():
-        if k.startswith("filter[") and k.endswith("]"):
-            # Removing 'filter[' from start and ']' from end
-            filter_key = k[7:-1]
-            filter_params[filter_key] = v
-
-    return filter_params
+    return {
+        k[7:-1]: v
+        for k, v in query_params.items()
+        if k.startswith("filter[") and k.endswith("]")
+    }
 
 
 @router.post(
diff --git a/libs/legacy/app/api/agents.py b/libs/legacy/app/api/agents.py
index 17a0f4cd3..5d87af076 100644
--- a/libs/legacy/app/api/agents.py
+++ b/libs/legacy/app/api/agents.py
@@ -119,9 +119,9 @@ async def read_library_agents(token=Depends(JWTBearer())):
 )
 async def read_agent(agentId: str, token=Depends(JWTBearer())):
     """Agent detail endpoint"""
-    agent = prisma.agent.find_unique(where={"id": agentId}, include={"prompt": True})
-
-    if agent:
+    if agent := prisma.agent.find_unique(
+        where={"id": agentId}, include={"prompt": True}
+    ):
         return {"success": True, "data": agent}
 
     logger.error(f"Agent with id: {agentId} not found")
diff --git a/libs/legacy/app/api/auth.py b/libs/legacy/app/api/auth.py
index ab57a37c3..f6c794a56 100644
--- a/libs/legacy/app/api/auth.py
+++ b/libs/legacy/app/api/auth.py
@@ -19,13 +19,12 @@
 @router.post("/auth/sign-in", response_model=SignInOutput)
 async def sign_in(signIn: SignIn):
     try:
-        user = prisma.user.find_first(
+        if user := prisma.user.find_first(
             where={
                 "email": signIn.email,
             },
             include={"profile": True},
-        )
-        if user:
+        ):
             validated = validatePassword(signIn.password, user.password)
             del user.password
 
@@ -34,16 +33,12 @@ async def sign_in(signIn: SignIn):
                 return {"success": True, "data": SignInOut(token=token, user=user)}
 
             logger.warning("Invalid password")
-            raise HTTPException(
-                status_code=status.HTTP_401_UNAUTHORIZED,
-                detail="Invalid credentials",
-            )
         else:
             logger.warning("User not found")
-            raise HTTPException(
-                status_code=status.HTTP_401_UNAUTHORIZED,
-                detail="Invalid credentials",
-            )
+        raise HTTPException(
+            status_code=status.HTTP_401_UNAUTHORIZED,
+            detail="Invalid credentials",
+        )
     except Exception as e:
         logger.error("Couldn't find user by email", exc_info=e)
         raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR)
@@ -96,16 +91,15 @@ async def oauth_handler(body: OAuth):
         )
         if user:
             return {"success": True, "data": user}
-        else:
-            user = prisma.user.create(
-                {
-                    "email": body.email,
-                    "provider": body.provider,
-                    "name": body.name,
-                    "accessToken": body.access_token,
-                }
-            )
-            prisma.profile.create({"userId": user.id})
+        user = prisma.user.create(
+            {
+                "email": body.email,
+                "provider": body.provider,
+                "name": body.name,
+                "accessToken": body.access_token,
+            }
+        )
+        prisma.profile.create({"userId": user.id})
 
-            if user:
-                return {"success": True, "data": user}
+        if user:
+            return {"success": True, "data": user}
diff --git a/libs/legacy/app/api/users.py b/libs/legacy/app/api/users.py
index 2c14350aa..8237542e3 100644
--- a/libs/legacy/app/api/users.py
+++ b/libs/legacy/app/api/users.py
@@ -34,8 +34,9 @@ async def read_user_me(token=Depends(JWTBearer())):
 @router.get("/users/{userId}", response_model=UserOutput)
 async def read_user(userId: str):
     try:
-        user = prisma.user.find_unique(where={"id": userId}, include={"profile": True})
-        if user:
+        if user := prisma.user.find_unique(
+            where={"id": userId}, include={"profile": True}
+        ):
             return {"success": True, "data": user}
         else:
             logger.error("Couldn't find user")
diff --git a/libs/legacy/app/lib/agents/agent.py b/libs/legacy/app/lib/agents/agent.py
index 81345fa91..9a2dca504 100644
--- a/libs/legacy/app/lib/agents/agent.py
+++ b/libs/legacy/app/lib/agents/agent.py
@@ -20,7 +20,7 @@ def get_agent(self, session: str | None = None) -> AgentExecutor:
         llm = self.agent_base._get_llm()
         memory = self.agent_base._get_memory(session)
         prompt = self.agent_base._get_prompt()
-        agent = LLMChain(
+        return LLMChain(
             llm=llm,
             memory=memory,
             verbose=True,
@@ -28,8 +28,6 @@ def get_agent(self, session: str | None = None) -> AgentExecutor:
             output_key="output",
         )
 
-        return agent
-
 
 class OpenAIAgent(AgentStrategy):
     def __init__(self, agent_base: AgentBase):
@@ -40,7 +38,7 @@ def get_agent(self, session: str | None = None) -> AgentExecutor:
         tools = self.agent_base._get_tools()
         memory = self.agent_base._get_memory(session)
         prompt = self.agent_base._get_prompt()
-        agent = initialize_agent(
+        return initialize_agent(
             tools=tools,
             llm=llm,
             agent=AgentType.OPENAI_FUNCTIONS,
@@ -58,8 +56,6 @@ def get_agent(self, session: str | None = None) -> AgentExecutor:
             return_intermediate_steps=True,
         )
 
-        return agent
-
 
 class ReactAgent(AgentStrategy):
     def __init__(self, agent_base):
@@ -79,12 +75,10 @@ def get_agent(self, session: str | None = None) -> AgentExecutor:
             stop=["\nObservation:"],
             allowed_tools=tool_names,
         )
-        agent = AgentExecutor.from_agent_and_tools(
+        return AgentExecutor.from_agent_and_tools(
             agent=agent_config,
             tools=tools,
             verbose=True,
             memory=memory,
             return_intermediate_steps=True,
         )
-
-        return agent
diff --git a/libs/legacy/app/lib/agents/base.py b/libs/legacy/app/lib/agents/base.py
index 80f8727ea..df6e5b5ae 100644
--- a/libs/legacy/app/lib/agents/base.py
+++ b/libs/legacy/app/lib/agents/base.py
@@ -86,7 +86,7 @@ def __init__(
         self.tools = self._get_agent_tools()
 
     def _get_api_key(self) -> str:
-        if self.llm["provider"] == "openai-chat" or self.llm["provider"] == "openai":
+        if self.llm["provider"] in ["openai-chat", "openai"]:
             return (
                 self.llm["api_key"]
                 if "api_key" in self.llm
@@ -161,7 +161,7 @@ def _get_llm(self, has_streaming: bool = True) -> Any:
                         ),
                     ],
                 )
-                if self.has_streaming and has_streaming != False
+                if self.has_streaming and has_streaming
                 else ChatOpenAI(
                     model_name=self.llm["model"],
                     openai_api_key=self._get_api_key(),
@@ -189,7 +189,7 @@ def _get_llm(self, has_streaming: bool = True) -> Any:
                         )
                     ],
                 )
-                if self.has_streaming and has_streaming != False
+                if self.has_streaming and has_streaming
                 else ChatAnthropic(anthropic_api_key=self._get_api_key())
             )
 
@@ -207,8 +207,10 @@ def _get_llm(self, has_streaming: bool = True) -> Any:
                         )
                     ],
                 )
-                if self.has_streaming and has_streaming != False
-                else Cohere(cohere_api_key=self._get_api_key(), model=self.llm["model"])
+                if self.has_streaming and has_streaming
+                else Cohere(
+                    cohere_api_key=self._get_api_key(), model=self.llm["model"]
+                )
             )
 
         if self.llm["provider"] == "azure-openai":
@@ -286,12 +288,10 @@ def _get_memory(self, session) -> ConversationBufferMemory:
         return ConversationBufferMemory(memory_key="chat_history", output_key="output")
 
     def _get_agent_documents(self) -> Any:
-        agent_documents = prisma.agentdocument.find_many(
+        return prisma.agentdocument.find_many(
             where={"agentId": self.id}, include={"document": True}
         )
 
-        return agent_documents
-
     def _get_tool_and_input_by_type(
         self, type: str, metadata: dict = None
     ) -> Tuple[Any, Any]:
@@ -407,12 +407,10 @@ def _get_tools(self) -> list[Tool]:
         return tools
 
     def _get_agent_tools(self) -> Any:
-        tools = prisma.agenttool.find_many(
+        return prisma.agenttool.find_many(
             where={"agentId": self.id}, include={"tool": True}
         )
 
-        return tools
-
     def _format_trace(self, trace: Any) -> dict:
         if self.documents or self.tools:
             return json.dumps(
@@ -463,7 +461,7 @@ def cache_message(
         metadata = {
             "agentId": agent_id,
-            "sessionId": str(session_id),
+            "sessionId": session_id,
             "text": query,
             "cached_message": ai_message,
             "type": "cache",
@@ -476,13 +474,11 @@ def get_cached_result(self, query: str) -> str | None:
         vectorstore: PineconeVectorStore = VectorStoreBase().get_database()
 
-        results = vectorstore.query(
+        if results := vectorstore.query(
             prompt=query,
             metadata_filter={"agentId": self.id},
             min_score=0.9,
-        )
-
-        if results:
+        ):
             timestamp: float = results[0].metadata.get("timestamp", 0.0)
 
             if timestamp and time.time() - timestamp > self.cache_ttl:
diff --git a/libs/legacy/app/lib/agents/factory.py b/libs/legacy/app/lib/agents/factory.py
index 519ab973c..171419bd5 100644
--- a/libs/legacy/app/lib/agents/factory.py
+++ b/libs/legacy/app/lib/agents/factory.py
@@ -9,13 +9,8 @@ def create_agent(agent_base: AgentBase):
         if agent_base.tools or agent_base.documents:
             return OpenAIAgent(agent_base)
-        return DefaultAgent(agent_base)
-
     elif agent_base.type == "REACT":
         if agent_base.tools or agent_base.documents:
             return ReactAgent(agent_base)
-        return DefaultAgent(agent_base)
-
-    else:
-        return DefaultAgent(agent_base)
+    return DefaultAgent(agent_base)
diff --git a/libs/legacy/app/lib/api_tokens.py b/libs/legacy/app/lib/api_tokens.py
index 3ba6c2789..c4fdfe5fa 100644
--- a/libs/legacy/app/lib/api_tokens.py
+++ b/libs/legacy/app/lib/api_tokens.py
@@ -2,5 +2,4 @@
 
 
 def generate_api_token() -> str:
-    api_token = uuid.uuid4().hex
-    return api_token
+    return uuid.uuid4().hex
diff --git a/libs/legacy/app/lib/auth/prisma.py b/libs/legacy/app/lib/auth/prisma.py
index 0ff13da58..f3d8c2793 100644
--- a/libs/legacy/app/lib/auth/prisma.py
+++ b/libs/legacy/app/lib/auth/prisma.py
@@ -24,9 +24,7 @@ def signJWT(user_id: str) -> Dict[str, str]:
         "exp": EXPIRES,
         "userId": user_id,
     }
-    token = jwt.encode(payload, jwtSecret, algorithm="HS256")
-
-    return token
+    return jwt.encode(payload, jwtSecret, algorithm="HS256")
 
 
 def decodeJWT(token: str) -> dict:
@@ -59,47 +57,40 @@ async def __call__(self, request: Request):
             JWTBearer, self
         ).__call__(request)
 
-        if credentials:
-            if not credentials.scheme == "Bearer":
-                raise HTTPException(
-                    status_code=403, detail="Invalid token or expired token."
-                )
-
-            if credentials.credentials.startswith("oauth_"):
-                accessToken = credentials.credentials.split("oauth_")[-1]
-                oauth_data = prisma.user.find_first(where={"accessToken": accessToken})
-                self.validateOAuthData(oauth_data)
-                return dict({"userId": oauth_data.id})
-            else:
-                if not self.verify_jwt(credentials.credentials):
-                    tokens_data = prisma.apitoken.find_first(
-                        where={"token": credentials.credentials}
-                    )
-
-                    if not tokens_data:
-                        raise HTTPException(
-                            status_code=403, detail="Invalid token or expired token."
-                        )
+        if not credentials:
+            raise HTTPException(status_code=403, detail="Invalid authorization code.")
+        if credentials.scheme != "Bearer":
+            raise HTTPException(
+                status_code=403, detail="Invalid token or expired token."
+            )
+        if credentials.credentials.startswith("oauth_"):
+            accessToken = credentials.credentials.split("oauth_")[-1]
+            oauth_data = prisma.user.find_first(where={"accessToken": accessToken})
+            self.validateOAuthData(oauth_data)
+            return dict({"userId": oauth_data.id})
+        else:
+            if not self.verify_jwt(credentials.credentials):
+                if tokens_data := prisma.apitoken.find_first(
+                    where={"token": credentials.credentials}
+                ):
                     return json.loads(tokens_data.json())
-                return decodeJWT(credentials.credentials)
+                else:
+                    raise HTTPException(
+                        status_code=403, detail="Invalid token or expired token."
+                    )
-        else:
-            raise HTTPException(status_code=403, detail="Invalid authorization code.")
+            return decodeJWT(credentials.credentials)
 
     def verify_jwt(self, jwtToken: str) -> bool:
-        isTokenValid: bool = False
         try:
             payload = decodeJWT(jwtToken)
         except Exception:
             payload = None
-        if payload:
-            isTokenValid = True
-
-        return isTokenValid
+        return bool(payload)
 
     def validateOAuthData(self, oauth_data) -> bool:
         if oauth_data.provider == "google":
@@ -114,19 +105,14 @@ def verify_github_token(self, accessToken: str) -> bool:
         uri = "https://api.github.com/user"
         headers = {"Authorization": f"token {accessToken}"}
         res = req.get(uri, headers=headers)
-        if res.status_code == 200:
-            return True
-        else:
-            return False
+        return res.status_code == 200
 
     def verify_google_token(self, accessToken: str) -> bool:
         try:
             id_info = id_token.verify_oauth2_token(
                 accessToken, requests.Request(), config("GOOGLE_CLIENT_ID")
             )
-            if id_info["aud"] != config("GOOGLE_CLIENT_ID"):
-                return False
-            return True
+            return id_info["aud"] == config("GOOGLE_CLIENT_ID")
         except ValueError:
             return False
 
@@ -136,9 +122,6 @@ def verify_azure_token(self, accessToken: str) -> bool:
                 exclude_managed_identity_credential=False
             )
             token = credentials.get_token("https://management.azure.com/.default")
-            if token.token == accessToken:
-                return True
-            else:
-                return False
+            return token.token == accessToken
         except Exception:
             return False
diff --git a/libs/legacy/app/lib/callbacks.py b/libs/legacy/app/lib/callbacks.py
index f4fb0d6f8..66c046efe 100644
--- a/libs/legacy/app/lib/callbacks.py
+++ b/libs/legacy/app/lib/callbacks.py
@@ -45,10 +45,9 @@ def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
             if self.agent_type == "OPENAI":
                 if gen.message.content != "":
                     self.on_llm_end_()
-                else:
-                    if "Final Answer" in gen.message.content:
-                        self.seen_final_answer[0] = False
-                        self.on_llm_end_()
+                elif "Final Answer" in gen.message.content:
+                    self.seen_final_answer[0] = False
+                    self.on_llm_end_()
 
     def on_llm_error(
         self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
diff --git a/libs/legacy/app/lib/documents.py b/libs/legacy/app/lib/documents.py
index 2a2548496..42833ff2b 100644
--- a/libs/legacy/app/lib/documents.py
+++ b/libs/legacy/app/lib/documents.py
@@ -80,23 +80,32 @@ def load_documents(type, metadata, url, content, from_page, to_page):
             loader = RemoteDepthReader(depth=depth)
             return loader.load_langchain_documents(url=url)
 
-        if type == "TXT":
-            file_response = content
-            if content is None:
-                if url is None:
-                    raise ValueError("URL must not be None when content is None.")
-                file_response = requests.get(url).text
+        if type == "GITHUB_REPOSITORY":
+            parsed_url = urlparse(url)
+            path_parts = parsed_url.path.split("/")  # type: ignore
+            repo_name = path_parts[2]
+
+            with tempfile.TemporaryDirectory() as temp_dir:
+                repo_path = f"{temp_dir}/{repo_name}/"  # type: ignore
+                loader = GitLoader(
+                    clone_url=url,
+                    repo_path=repo_path,
+                    branch=metadata["branch"],  # type: ignore
+                )
+                return loader.load_and_split()
-            if file_response is not None:
-                with NamedTemporaryFile(suffix=".txt", delete=True) as temp_file:
-                    temp_file.write(file_response.encode())
-                    temp_file.flush()
-                    loader = TextLoader(file_path=temp_file.name)
-                    return loader.load()
-            else:
+        elif type == "MARKDOWN":
+            if url is None:
+                raise ValueError("URL must not be None for MARKDOWN type.")
+            if not (file_response := requests.get(url).text):
                 raise ValueError("file_response must not be None.")
-        if type == "PDF":
+            with NamedTemporaryFile(suffix=".md", delete=True) as temp_file:
+                temp_file.write(file_response.encode())
+                temp_file.flush()
+                loader = UnstructuredMarkdownLoader(file_path=temp_file.name)
+                return loader.load()
+        elif type == "PDF":
             if url is None:
                 raise ValueError("URL must not be None for PDF type.")
             loader = CustomPDFPlumberLoader(
@@ -104,48 +113,35 @@ def load_documents(type, metadata, url, content, from_page, to_page):
             )
             return loader.load()
 
-        if type == "URL":
+        elif type == "TXT":
+            file_response = content
+            if file_response is None:
+                if url is None:
+                    raise ValueError("URL must not be None when content is None.")
+                file_response = requests.get(url).text
+
+            if file_response is None:
+                raise ValueError("file_response must not be None.")
+
+            with NamedTemporaryFile(suffix=".txt", delete=True) as temp_file:
+                temp_file.write(file_response.encode())
+                temp_file.flush()
+                loader = TextLoader(file_path=temp_file.name)
+                return loader.load()
+        elif type == "URL":
             if url is None:
                 raise ValueError("URL must not be None for URL type.")
             url_list = url.split(",")
             loader = WebBaseLoader(url_list)
             return loader.load()
 
-        if type == "YOUTUBE":
+        elif type == "YOUTUBE":
             if url is None:
                 raise ValueError("URL must not be None for YOUTUBE type.")
             video_id = url.split("youtube.com/watch?v=")[-1]
             loader = YoutubeLoader(video_id=video_id)
             return loader.load()
 
-        if type == "MARKDOWN":
-            if url is None:
-                raise ValueError("URL must not be None for MARKDOWN type.")
-            file_response = requests.get(url).text
-
-            if file_response:
-                with NamedTemporaryFile(suffix=".md", delete=True) as temp_file:
-                    temp_file.write(file_response.encode())
-                    temp_file.flush()
-                    loader = UnstructuredMarkdownLoader(file_path=temp_file.name)
-                    return loader.load()
-            else:
-                raise ValueError("file_response must not be None.")
-
-        if type == "GITHUB_REPOSITORY":
-            parsed_url = urlparse(url)
-            path_parts = parsed_url.path.split("/")  # type: ignore
-            repo_name = path_parts[2]
-
-            with tempfile.TemporaryDirectory() as temp_dir:
-                repo_path = f"{temp_dir}/{repo_name}/"  # type: ignore
-                loader = GitLoader(
-                    clone_url=url,
-                    repo_path=repo_path,
-                    branch=metadata["branch"],  # type: ignore
-                )
-                return loader.load_and_split()
-
         return []
 
     except Exception as e:
@@ -183,8 +179,9 @@ def upsert_document(
         logger.info(
             f"Upserting document for document_id: {document_id} of type: {type}"
         )
-        documents = load_documents(type, metadata, url, content, from_page, to_page)
-        if documents:
+        if documents := load_documents(
+            type, metadata, url, content, from_page, to_page
+        ):
             embed_documents(documents, document_id, text_splitter)
         logger.info(f"Upsert for document_id: {document_id} completed.")
     except Exception as e:
diff --git a/libs/legacy/app/lib/loaders/sitemap.py b/libs/legacy/app/lib/loaders/sitemap.py
index 5d2f95c5b..78be40aed 100644
--- a/libs/legacy/app/lib/loaders/sitemap.py
+++ b/libs/legacy/app/lib/loaders/sitemap.py
@@ -24,9 +24,7 @@ def fetch_text(self, url):
         response.raise_for_status()  # Raise exception for HTTP errors
         soup = BeautifulSoup(response.content, "html.parser")
         raw_text = soup.get_text(separator=" ").strip()
-        cleaned_text = re.sub(r"\s+", " ", raw_text)
-
-        return cleaned_text
+        return re.sub(r"\s+", " ", raw_text)
 
     def matches_any_pattern(self, url):
         """Check if the URL matches any of the given patterns."""
diff --git a/libs/legacy/app/lib/parsers.py b/libs/legacy/app/lib/parsers.py
index 6a555e5b8..a448ac788 100644
--- a/libs/legacy/app/lib/parsers.py
+++ b/libs/legacy/app/lib/parsers.py
@@ -91,16 +91,15 @@ def lazy_parse(self, blob: Blob) -> Iterator[Document]:
             if self.to_page is None:
                 # by default, starts from 1 and processes the whole document
                 doc = pdfplumber.open(file_path)
+            elif self.to_page > 0:
+                """Parse till the maximum page number provided"""
+                doc = pdfplumber.open(
+                    file_path, pages=list(range(self.from_page, self.to_page))
+                )
             else:
-                if self.to_page > 0:
-                    """Parse till the maximum page number provided"""
-                    doc = pdfplumber.open(
-                        file_path, pages=list(range(self.from_page, self.to_page))
-                    )
-                else:
-                    raise ValueError(
-                        "Value of to_page should be greater than equal to 1."
-                    )
+                raise ValueError(
+                    "Value of to_page should be greater than equal to 1."
+                )
 
         yield from [
             Document(
diff --git a/libs/legacy/app/lib/splitters.py b/libs/legacy/app/lib/splitters.py
index c0c8c75e4..9d642ee31 100644
--- a/libs/legacy/app/lib/splitters.py
+++ b/libs/legacy/app/lib/splitters.py
@@ -47,8 +47,7 @@ def character_splitter(self) -> list[Document]:
         text_splitter = CharacterTextSplitter(
             chunk_size=self.chunk_size, chunk_overlap=self.chunk_overlap
         )
-        docs = text_splitter.split_documents(self.documents)
-        return docs
+        return text_splitter.split_documents(self.documents)
 
     def recursive_splitter(self) -> list[Document]:
         """
@@ -81,8 +80,7 @@ def token_splitter(self) -> list[Document]:
             chunk_size=self.chunk_size, chunk_overlap=self.chunk_overlap
         )
         texts = text_splitter.split_text(self.documents)
-        docs = [Document(page_content=text) for text in texts]
-        return docs
+        return [Document(page_content=text) for text in texts]
 
     def spacy_splitter(self) -> list[Document]:
         """
@@ -91,8 +89,7 @@ def spacy_splitter(self) -> list[Document]:
         text_splitter = SpacyTextSplitter(chunk_size=self.chunk_size)
         texts = text_splitter.split_text(self.documents)
-        docs = [Document(page_content=text) for text in texts]
-        return docs
+        return [Document(page_content=text) for text in texts]
 
     def nltk_splitter(self) -> list[Document]:
         """
@@ -101,8 +98,7 @@ def nltk_splitter(self) -> list[Document]:
         text_splitter = NLTKTextSplitter(chunk_size=self.chunk_size)
         texts = text_splitter.split_text(self.documents)
-        docs = [Document(page_content=text) for text in texts]
-        return docs
+        return [Document(page_content=text) for text in texts]
 
     def huggingface_splitter(self) -> list[Document]:
         """
@@ -122,5 +118,4 @@ def huggingface_splitter(self) -> list[Document]:
             tokenizer, chunk_size=self.chunk_size, chunk_overlap=self.chunk_overlap
         )
         texts = text_splitter.split_text(self.documents)
-        docs = [Document(page_content=text) for text in texts]
-        return docs
+        return [Document(page_content=text) for text in texts]
diff --git a/libs/legacy/app/lib/tools.py b/libs/legacy/app/lib/tools.py
index 328b3333b..d93bcfe37 100644
--- a/libs/legacy/app/lib/tools.py
+++ b/libs/legacy/app/lib/tools.py
@@ -32,18 +32,14 @@ class ToolDescription(Enum):
 
 
 def get_search_tool() -> Any:
-    search = BingSearchAPIWrapper(
+    return BingSearchAPIWrapper(
         bing_search_url=config("BING_SEARCH_URL"),
         bing_subscription_key=config("BING_SUBSCRIPTION_KEY"),
     )
 
-    return search
-
 
 def get_wolfram_alpha_tool() -> Any:
-    wolfram = WolframAlphaAPIWrapper()
-
-    return wolfram
+    return WolframAlphaAPIWrapper()
 
 
 def get_replicate_tool(metadata: dict) -> Any:
@@ -62,15 +58,13 @@ def get_zapier_nla_tool(metadata: dict, llm: Any) -> Any:
     zapier = ZapierNLAWrapper(zapier_nla_api_key=metadata["zapier_nla_api_key"])
     toolkit = ZapierToolkit.from_zapier_nla_wrapper(zapier)
-    agent = initialize_agent(
+    return initialize_agent(
         toolkit.get_tools(),
         llm,
         agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
         verbose=True,
     )
 
-    return agent
-
 
 def get_chatgpt_plugin_tool(metadata: dict) -> Any:
     plugin_url = metadata["chatgptPluginURL"]
@@ -83,12 +77,10 @@ def get_openapi_tool(metadata: dict) -> Any:
     openapi_url = metadata.get("openApiUrl")
     headers = metadata.get("headers")
-    agent = get_openapi_chain(
+    return get_openapi_chain(
         spec=openapi_url, headers=json.loads(headers) if headers else None
     )
 
-    return agent
-
 
 class AgentTool:
     def __init__(self, metadata: dict, api_key: str) -> Any:
@@ -144,13 +136,11 @@ def run(self, *args) -> str:
         documents = [Document(page_content=text) for text in document_text]
 
         chain = load_summarize_chain(self.llm, chain_type="stuff")
-        summary = chain.run(
+        return chain.run(
             input_documents=documents,
             question="Write a concise summary within 200 words.",
         )
 
-        return summary
-
 
 class DocumentTool:
     def __init__(self, document_id: str, query_type: str = "document"):
@@ -159,7 +149,7 @@ def __init__(self, document_id: str, query_type: str = "document"):
 
     def run(self, prompt: str) -> str:
         """Use the tool."""
-        results = (
+        return (
             VectorStoreBase()
             .get_database()
             .query_documents(
@@ -169,5 +159,3 @@ def run(self, prompt: str) -> str:
                 top_k=3,
             )
         )
-
-        return results
diff --git a/libs/legacy/app/lib/vectorstores/pinecone.py b/libs/legacy/app/lib/vectorstores/pinecone.py
index 9306ce50f..74ead6ec1 100644
--- a/libs/legacy/app/lib/vectorstores/pinecone.py
+++ b/libs/legacy/app/lib/vectorstores/pinecone.py
@@ -85,8 +85,7 @@ def embed_documents(self, documents: list[Document], batch_size: int = 100):
         def batch_generator(chunks, batch_size):
             for i in range(0, len(chunks), batch_size):
                 i_end = min(len(chunks), i + batch_size)
-                batch = chunks[i:i_end]
-                yield batch
+                yield chunks[i:i_end]
 
         batch_gen = batch_generator(chunks, batch_size)
 
@@ -127,13 +126,11 @@ def _format_response(self, response: QueryResponse) -> list[Response]:
             *[self._extract_match_data(match) for match in response["matches"]]
         )
 
-        responses = [
+        return [
             Response(id=id, text=text, metadata=meta)
             for id, text, meta in zip(ids, texts, metadata)
         ]
 
-        return responses
-
     def query(
         self,
         prompt: str,
@@ -164,8 +161,7 @@ def query(
             if match["score"] >= min_score
         ]
 
-        formatted_responses = self._format_response(raw_responses)
-        return formatted_responses
+        return self._format_response(raw_responses)
 
     def query_documents(
         self,
@@ -215,20 +211,20 @@ def delete(self, document_id: str):
                 include_values=False,
             )
 
-            vector_ids = [match["id"] for match in documents_in_namespace["matches"]]
-
-            if len(vector_ids) == 0:
+            if vector_ids := [
+                match["id"] for match in documents_in_namespace["matches"]
+            ]:
                 logger.info(
-                    f"No vectors found in namespace `{document_id}`. "
-                    f"Deleting `{document_id}` using default namespace."
+                    f"Deleting {len(vector_ids)} documents in namespace {document_id}"
                 )
-                self.index.delete(filter={"document_id": document_id}, delete_all=False)
+                self.index.delete(ids=vector_ids, delete_all=False)
             else:
                 logger.info(
-                    f"Deleting {len(vector_ids)} documents in namespace {document_id}"
+                    f"No vectors found in namespace `{document_id}`. "
+                    f"Deleting `{document_id}` using default namespace."
                 )
-                self.index.delete(ids=vector_ids, delete_all=False)
+                self.index.delete(filter={"document_id": document_id}, delete_all=False)
 
         except Exception as e:
Error: {e}") diff --git a/libs/superagent/app/agents/base.py b/libs/superagent/app/agents/base.py index 598d54b3e..d5152cd0c 100644 --- a/libs/superagent/app/agents/base.py +++ b/libs/superagent/app/agents/base.py @@ -63,8 +63,7 @@ async def _get_tools( ) tools.append(tool) for agent_tool in agent_tools: - tool_info = TOOL_TYPE_MAPPING.get(agent_tool.tool.type) - if tool_info: + if tool_info := TOOL_TYPE_MAPPING.get(agent_tool.tool.type): tool = create_tool( tool_class=tool_info["class"], name=slugify(agent_tool.tool.name), diff --git a/libs/superagent/app/tools/bing_search.py b/libs/superagent/app/tools/bing_search.py index afcd5f77a..fe276ed75 100644 --- a/libs/superagent/app/tools/bing_search.py +++ b/libs/superagent/app/tools/bing_search.py @@ -16,8 +16,7 @@ def _run(self, search_query: str) -> str: bing_search_url=bing_search_url, bing_subscription_key=bing_subscription_key, ) - output = search.run(search_query) - return output + return search.run(search_query) async def _arun(self, search_query: str) -> str: bing_search_url = self.metadata["bingSearchUrl"] @@ -27,5 +26,4 @@ async def _arun(self, search_query: str) -> str: bing_subscription_key=bing_subscription_key, ) loop = asyncio.get_event_loop() - output = await loop.run_in_executor(None, search.run, search_query) - return output + return await loop.run_in_executor(None, search.run, search_query) diff --git a/libs/superagent/app/tools/datasource.py b/libs/superagent/app/tools/datasource.py index 2099f2ea2..e9f083c8a 100644 --- a/libs/superagent/app/tools/datasource.py +++ b/libs/superagent/app/tools/datasource.py @@ -75,13 +75,12 @@ def _run( ) -> str: """Use the tool.""" pinecone = PineconeVectorStore() - result = pinecone.query_documents( + return pinecone.query_documents( prompt=question, datasource_id=self.metadata["datasource_id"], query_type=self.metadata["query_type"], top_k=3, ) - return result async def _arun( self, @@ -89,13 +88,12 @@ async def _arun( ) -> str: """Use the tool asynchronously.""" pinecone = PineconeVectorStore() - result = pinecone.query_documents( + return pinecone.query_documents( prompt=question, datasource_id=self.metadata["datasource_id"], query_type=self.metadata["query_type"], top_k=3, ) - return result class StructuredDatasourceTool(BaseTool): @@ -120,8 +118,7 @@ def _load_csv_data(self, datasource: Datasource): file_content = StringIO(response.text) else: file_content = StringIO(datasource.content) - df = pd.read_csv(file_content) - return df + return pd.read_csv(file_content) def _run( self, @@ -146,8 +143,7 @@ def _run( verbose=True, agent_type=AgentType.OPENAI_FUNCTIONS, ) - output = agent.run(question) - return output + return agent.run(question) async def _arun( self, @@ -172,5 +168,4 @@ async def _arun( verbose=True, agent_type=AgentType.OPENAI_FUNCTIONS, ) - output = await agent.arun(question) - return output + return await agent.arun(question) diff --git a/libs/superagent/app/tools/e2b.py b/libs/superagent/app/tools/e2b.py index 95900b264..65b40a9b9 100644 --- a/libs/superagent/app/tools/e2b.py +++ b/libs/superagent/app/tools/e2b.py @@ -11,17 +11,14 @@ def _run(self, python_code: str) -> str: api_token = config("E2B_API_KEY") output, err = run_code_sync("Python3_DataAnalysis", python_code, api_token) - if err: - return "There was following error during execution: " + err - - return output + return f"There was following error during execution: {err}" if err else output async def _arun(self, python_code: str) -> str: api_token = config("E2B_API_KEY") output, err = await 
         output, err = await run_code("Python3_DataAnalysis", python_code, api_token)
 
         if err:
-            return "There was following error during execution: " + err
+            return f"There was following error during execution: {err}"
 
         print("output", output)
         return output
diff --git a/libs/superagent/app/tools/metaphor.py b/libs/superagent/app/tools/metaphor.py
index 761c10473..0d13bba78 100644
--- a/libs/superagent/app/tools/metaphor.py
+++ b/libs/superagent/app/tools/metaphor.py
@@ -11,12 +11,10 @@ def _run(self, search_query: str) -> str:
         search = MetaphorSearchAPIWrapper(
             metaphor_api_key=self.metadata["metaphorApiKey"]
         )
-        output = search.results(search_query, 10, use_autoprompt=True)
-        return output
+        return search.results(search_query, 10, use_autoprompt=True)
 
     async def _arun(self, search_query: str) -> str:
         search = MetaphorSearchAPIWrapper(
             metaphor_api_key=self.metadata["metaphorApiKey"]
         )
-        output = await search.results_async(search_query, 10, use_autoprompt=True)
-        return output
+        return await search.results_async(search_query, 10, use_autoprompt=True)
diff --git a/libs/superagent/app/tools/openapi.py b/libs/superagent/app/tools/openapi.py
index d70eb379d..d0ad39c2a 100644
--- a/libs/superagent/app/tools/openapi.py
+++ b/libs/superagent/app/tools/openapi.py
@@ -16,8 +16,7 @@ def _run(self, input: str) -> str:
         agent = get_openapi_chain(
             spec=openapi_url, headers=json.loads(headers) if headers else None
         )
-        output = agent.run(input)
-        return output
+        return agent.run(input)
 
     async def _arun(self, input: str) -> str:
         openapi_url = self.metadata["openApiUrl"]
@@ -26,5 +25,4 @@ async def _arun(self, input: str) -> str:
             spec=openapi_url, headers=json.loads(headers) if headers else None
         )
         loop = asyncio.get_event_loop()
-        output = await loop.run_in_executor(None, agent.run, input)
-        return output
+        return await loop.run_in_executor(None, agent.run, input)
diff --git a/libs/superagent/app/tools/pubmed.py b/libs/superagent/app/tools/pubmed.py
index 4f302f796..338d589d5 100644
--- a/libs/superagent/app/tools/pubmed.py
+++ b/libs/superagent/app/tools/pubmed.py
@@ -10,11 +10,9 @@ class PubMed(BaseTool):
     def _run(self, search_query: str) -> str:
         pubmed = PubmedQueryRun(args_schema=self.args_schema)
-        output = pubmed.run(search_query)
-        return output
+        return pubmed.run(search_query)
 
     async def _arun(self, search_query: str) -> str:
         pubmed = PubmedQueryRun(args_schema=self.args_schema)
         loop = asyncio.get_event_loop()
-        output = await loop.run_in_executor(None, pubmed.run, search_query)
-        return output
+        return await loop.run_in_executor(None, pubmed.run, search_query)
diff --git a/libs/superagent/app/tools/replicate.py b/libs/superagent/app/tools/replicate.py
index f20d43700..160504e7f 100644
--- a/libs/superagent/app/tools/replicate.py
+++ b/libs/superagent/app/tools/replicate.py
@@ -14,8 +14,7 @@ def _run(self, prompt: str) -> str:
         model = ReplicateModel(
             model=model, input=input, api_token=api_token, replicate_api_token=api_token
         )
-        output = model.predict(prompt)
-        return output
+        return model.predict(prompt)
 
     async def _arun(self, prompt: str) -> str:
         model = self.metadata["model"]
@@ -24,5 +23,4 @@ async def _arun(self, prompt: str) -> str:
         model = ReplicateModel(
             model=model, input=input, api_token=api_token, replicate_api_token=api_token
         )
-        output = await model.apredict(prompt)
-        return output
+        return await model.apredict(prompt)
diff --git a/libs/superagent/app/tools/wolfram_alpha.py b/libs/superagent/app/tools/wolfram_alpha.py
index 319bae12f..6a6b66158 100644
--- a/libs/superagent/app/tools/wolfram_alpha.py
+++ b/libs/superagent/app/tools/wolfram_alpha.py
@@ -18,5 +18,4 @@ async def _arun(self, input: str) -> str:
         app_id = self.metadata["appId"]
         wolfram = WolframAlphaAPIWrapper(wolfram_alpha_appid=app_id)
         loop = asyncio.get_event_loop()
-        output = await loop.run_in_executor(None, wolfram.run, input)
-        return output
+        return await loop.run_in_executor(None, wolfram.run, input)
diff --git a/libs/superagent/app/tools/zapier.py b/libs/superagent/app/tools/zapier.py
index 3a0ca35a7..f8f356362 100644
--- a/libs/superagent/app/tools/zapier.py
+++ b/libs/superagent/app/tools/zapier.py
@@ -22,8 +22,7 @@ def _run(self, input: str) -> str:
             agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
             verbose=True,
         )
-        output = agent.run(input)
-        return output
+        return agent.run(input)
 
     async def _arun(self, input: str) -> str:
         zapier_nla_api_key = self.metadata["zapierNlaApiKey"]
@@ -35,5 +34,4 @@ async def _arun(self, input: str) -> str:
             agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
             verbose=True,
         )
-        output = await agent.arun(input)
-        return output
+        return await agent.arun(input)
diff --git a/libs/superagent/app/utils/api.py b/libs/superagent/app/utils/api.py
index 524fced6e..755e770ad 100644
--- a/libs/superagent/app/utils/api.py
+++ b/libs/superagent/app/utils/api.py
@@ -19,8 +19,7 @@ def handle_exception(e):
 
 
 def generate_jwt(data: dict):
-    token = jwt.encode({**data}, config("JWT_SECRET"), algorithm="HS256")
-    return token
+    return jwt.encode({**data}, config("JWT_SECRET"), algorithm="HS256")
 
 
 def decode_jwt(token: str):
diff --git a/libs/superagent/app/vectorstores/pinecone.py b/libs/superagent/app/vectorstores/pinecone.py
index 691ca8437..f8512c294 100644
--- a/libs/superagent/app/vectorstores/pinecone.py
+++ b/libs/superagent/app/vectorstores/pinecone.py
@@ -85,8 +85,7 @@ def embed_documents(self, documents: list[Document], batch_size: int = 100):
         def batch_generator(chunks, batch_size):
             for i in range(0, len(chunks), batch_size):
                 i_end = min(len(chunks), i + batch_size)
-                batch = chunks[i:i_end]
-                yield batch
+                yield chunks[i:i_end]
 
         batch_gen = batch_generator(chunks, batch_size)
 
@@ -127,13 +126,11 @@ def _format_response(self, response: QueryResponse) -> list[Response]:
             *[self._extract_match_data(match) for match in response["matches"]]
         )
 
-        responses = [
+        return [
             Response(id=id, text=text, metadata=meta)
             for id, text, meta in zip(ids, texts, metadata)
         ]
 
-        return responses
-
     def query(
         self,
         prompt: str,
@@ -164,8 +161,7 @@ def query(
             if match["score"] >= min_score
         ]
 
-        formatted_responses = self._format_response(raw_responses)
-        return formatted_responses
+        return self._format_response(raw_responses)
 
     def query_documents(
         self,
@@ -215,9 +211,15 @@ def delete(self, datasource_id: str):
                 include_values=False,
             )
 
-            vector_ids = [match["id"] for match in documents_in_namespace["matches"]]
+            if vector_ids := [
+                match["id"] for match in documents_in_namespace["matches"]
+            ]:
+                logger.info(
+                    f"Deleting {len(vector_ids)} documents in namespace {datasource_id}"
+                )
+                self.index.delete(ids=vector_ids, delete_all=False)
 
-            if len(vector_ids) == 0:
+            else:
                 logger.info(
                     f"No vectors found in namespace `{datasource_id}`. "
                     f"Deleting `{datasource_id}` using default namespace."
@@ -226,12 +228,6 @@ def delete(self, datasource_id: str):
                     filter={"datasource_id": datasource_id}, delete_all=False
                 )
 
-            else:
-                logger.info(
-                    f"Deleting {len(vector_ids)} documents in namespace {datasource_id}"
-                )
-                self.index.delete(ids=vector_ids, delete_all=False)
-
         except Exception as e:
             logger.error(f"Failed to delete {datasource_id}. Error: {e}")