From 6a2fbdbf1735b7ff580fa03e44e9eea28c9c51a4 Mon Sep 17 00:00:00 2001 From: Yi Lu Date: Wed, 12 Jul 2023 17:53:52 -0700 Subject: [PATCH] Fix tests (#86) --- autochain/memory/redis_memory.py | 4 ++-- autochain/tools/internal_search/pinecone_tool.py | 1 - overrides/main.html | 6 ------ test_utils/__init__.py | 0 {tests/common => test_utils}/pinecone_mocks.py | 3 ++- tests/agent/test_conversational_agent.py | 3 +++ tests/memory/test_long_term_memory.py | 3 ++- tests/tools/test_pinecone_tool.py | 2 +- 8 files changed, 10 insertions(+), 12 deletions(-) delete mode 100644 overrides/main.html create mode 100644 test_utils/__init__.py rename {tests/common => test_utils}/pinecone_mocks.py (97%) diff --git a/autochain/memory/redis_memory.py b/autochain/memory/redis_memory.py index 46a54fa..de090e0 100644 --- a/autochain/memory/redis_memory.py +++ b/autochain/memory/redis_memory.py @@ -1,5 +1,5 @@ import pickle -from typing import Any, Optional +from typing import Any, Optional, Dict from autochain.agent.message import ( ChatMessageHistory, @@ -39,7 +39,7 @@ def load_memory( return default return pickle.loads(pickled) - def load_conversation(self, **kwargs: dict[str, Any]) -> ChatMessageHistory: + def load_conversation(self, **kwargs: Dict[str, Any]) -> ChatMessageHistory: """Return chat message history.""" redis_key = self.redis_key_prefix + f":{ChatMessageHistory.__name__}" return ChatMessageHistory(messages=self.load_memory(redis_key, [])) diff --git a/autochain/tools/internal_search/pinecone_tool.py b/autochain/tools/internal_search/pinecone_tool.py index f98a97c..e1bab74 100644 --- a/autochain/tools/internal_search/pinecone_tool.py +++ b/autochain/tools/internal_search/pinecone_tool.py @@ -70,7 +70,6 @@ def _format_output(query_response: QueryResponse) -> str: response: QueryResponse = self.index.query( vector=encoding, top_k=top_k, include_values=include_values ) - return _format_output(response) def add_docs(self, docs: List[PineconeDoc], **kwargs): diff --git a/overrides/main.html b/overrides/main.html deleted file mode 100644 index 8ae9379..0000000 --- a/overrides/main.html +++ /dev/null @@ -1,6 +0,0 @@ -{% extends "base.html" %} - -{% block extrahead %} - -{{ super() }} -{% endblock %} \ No newline at end of file diff --git a/test_utils/__init__.py b/test_utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/common/pinecone_mocks.py b/test_utils/pinecone_mocks.py similarity index 97% rename from tests/common/pinecone_mocks.py rename to test_utils/pinecone_mocks.py index 76ebd4b..a74a1b0 100644 --- a/tests/common/pinecone_mocks.py +++ b/test_utils/pinecone_mocks.py @@ -9,7 +9,8 @@ class MockIndex: - kv = {} + def __init__(self): + self.kv = {} def upsert(self, id_vectors, *args, **kwargs): for id, vector in id_vectors: diff --git a/tests/agent/test_conversational_agent.py b/tests/agent/test_conversational_agent.py index 737e087..979752b 100644 --- a/tests/agent/test_conversational_agent.py +++ b/tests/agent/test_conversational_agent.py @@ -1,4 +1,5 @@ import json +import os from unittest import mock import pytest @@ -84,6 +85,7 @@ def openai_response_fixture(): def test_should_answer_prompt(openai_should_answer_fixture): + os.environ["OPENAI_API_KEY"] = "mock_api_key" agent = ConversationalAgent.from_llm_and_tools(llm=ChatOpenAI(), tools=[]) inputs = {"history": "good user query"} @@ -97,6 +99,7 @@ def test_should_answer_prompt(openai_should_answer_fixture): def test_plan(openai_response_fixture): + os.environ["OPENAI_API_KEY"] = "mock_api_key" agent = ConversationalAgent.from_llm_and_tools( llm=ChatOpenAI(), tools=[HandOffToAgent()] ) diff --git a/tests/memory/test_long_term_memory.py b/tests/memory/test_long_term_memory.py index a55561b..ba836b8 100644 --- a/tests/memory/test_long_term_memory.py +++ b/tests/memory/test_long_term_memory.py @@ -2,8 +2,9 @@ from autochain.memory.long_term_memory import LongTermMemory from autochain.tools.internal_search.chromadb_tool import ChromaDoc, ChromaDBSearch from autochain.tools.internal_search.pinecone_tool import PineconeSearch, PineconeDoc -from tests.common.pinecone_mocks import ( +from test_utils.pinecone_mocks import ( DummyEncoder, + pinecone_index_fixture ) diff --git a/tests/tools/test_pinecone_tool.py b/tests/tools/test_pinecone_tool.py index 11a76a6..022d9a3 100644 --- a/tests/tools/test_pinecone_tool.py +++ b/tests/tools/test_pinecone_tool.py @@ -1,5 +1,5 @@ from autochain.tools.internal_search.pinecone_tool import PineconeSearch, PineconeDoc -from tests.common.pinecone_mocks import ( +from test_utils.pinecone_mocks import ( DummyEncoder, pinecone_index_fixture, )